Как работать с результатом org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs()

Я запускаю posenet (который является CNN) на Android с tflite. Модель имеет несколько выходных массивов со следующими размерностями: 1x14x14x17, 1x14x14x34, 1x14x14x32, 1x14x14x32

Поэтому запуск интерпретатора java tflite с

import org.tensorflow.lite.Interpreter;
Interpreter tflite;
...
tflite.runForMultipleInputsOutputs(inputs,outputs)

я могу получить доступ к четырем выходным тензорам с помощью tflite.getOutputTensor(i) или с outputs.get(i) (с i эл. [0,3]), поскольку outputs представляет собой HashMap заполненный java.nio.HeapByteBuffer объектами.

Как я могу преобразовать эти выходные данные или тензоры tflite в многомерные массивы Java (что-то вроде float[][][][];), чтобы иметь возможность выполнять над ними математические вычисления?

Пользовательский скаляр GraphQL
Пользовательский скаляр GraphQL
Листовые узлы системы типов GraphQL называются скалярами. Достигнув скалярного типа, невозможно спуститься дальше по иерархии типов. Скалярный тип...
Как вычислять биты и понимать побитовые операторы в Java - объяснение с примерами
Как вычислять биты и понимать побитовые операторы в Java - объяснение с примерами
В компьютерном программировании биты играют важнейшую роль в представлении и манипулировании данными на двоичном уровне. Побитовые операции...
Поднятие тревоги для долго выполняющихся методов в Spring Boot
Поднятие тревоги для долго выполняющихся методов в Spring Boot
Приходилось ли вам сталкиваться с требованиями, в которых вас могли попросить поднять тревогу или выдать ошибку, когда метод Java занимает больше...
Полный курс Java для разработчиков веб-сайтов и приложений
Полный курс Java для разработчиков веб-сайтов и приложений
Получите сертификат Java Web и Application Developer, используя наш курс.
4
0
2 408
3
Перейти к ответу Данный вопрос помечен как решенный

Ответы 3

Ответ принят как подходящий

Определение выходных данных, подобных следующему, позволяет вам работать с собственными массивами Java, чего я и хотел:

out1 = new float[1][14][14][17];
out2 = new float[1][14][14][34];
out3 = new float[1][14][14][32];
out4 = new float[1][14][14][32];
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, out1);
outputs.put(1, out2);
outputs.put(2, out3);
outputs.put(3, out4);
// The shape of *1* output's tensor
int[] OutputShape;
// The type of the *1* output's tensor
DataType OutputDataType;
// The multi-tensor ready storage
outputProbabilityBuffers = new HashMap<>();

ByteBuffer x;
// For each model's tensors (there are getOutputTensorCount() of them for this tflite model)
for (int i = 0; i < tflite.getOutputTensorCount(); i++) {
    OutputShape = tflite.getOutputTensor(i).shape();
    OutputDataType = tflite.getOutputTensor(i).dataType();
    x = TensorBuffer.createFixedSize(OutputShape, OutputDataType).getBuffer();
    outputProbabilityBuffers.put(i, x);
    LOGGER.d("Created a buffer of %d bytes for tensor %d.", x.limit(), i);
}

LOGGER.d("Created a tflite output of %d output tensors.", outputProbabilityBuffers.size());

Пример вывода:

Classifier: Created a buffer of 11264 bytes for tensor 0.
Classifier: Created a buffer of 11264 bytes for tensor 1.
Classifier: Created a buffer of 4 bytes for tensor 2.
Classifier: Created a buffer of 11264 bytes for tensor 3.
Classifier: Created a tflite output of 4 output tensors.

И используйте его таким образом:

Object[] inputs = { your_regular_input };
tflite.runForMultipleInputsOutputs(inputs, outputProbabilityBuffers);

Выход: https://www.tensorflow.org/lite/models/object_detection/overview#output

val locations = outputs.getValue(0).asFlowArray(),
val classes = outputs.getValue(1).asFlowArray(),
val scores = outputs.getValue(2).asFlowArray(),
val detections = outputs.getValue(3).asFlowArray()

Другие вопросы по теме