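java - TensorFlow Lite interpreter crashes when run after resizeInput()
Problem description
I have a TensorFlow Lite model whose input shape is {1,320,320,3} and whose output tensors have the data types {FLOAT32, FLOAT32, INT32}. After I resize the input to {1,224,320,3}, Interpreter#runForMultipleInputsOutputs crashes.
My TensorFlow Lite dependencies: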
implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true }
Create the input tensor and prepare the output tensors:
inputImageBuffer = new TensorImage(imageDataType);
outputArr = new TensorBuffer[tflite.getOutputTensorCount()];
resizeOutputArr = new TensorBuffer[tflite.getOutputTensorCount()];
for (int i = 0, count = tflite.getOutputTensorCount(); i < count; i++) {
    DataType pDataType = tflite.getOutputTensor(i).dataType();
    if (pDataType == DataType.INT32) pDataType = DataType.FLOAT32;
    outputArr[i] = TensorBuffer.createFixedSize(i == 0 ? new int[]{imageSizeY, imageSizeX} : new int[]{imageSizeX}, pDataType);
    resizeOutputArr[i] = TensorBuffer.createDynamic(pDataType);
}
Resize the input image buffer:
// Loads bitmap into a TensorImage.
inputImageBuffer.load(bitmap);
float sc = bitmap.getWidth() / 320.0f;
int scaleWid = Math.round(bitmap.getWidth() / sc);
int scaleHei = Math.round(bitmap.getHeight() / sc);
imageSizeY = scaleHei / 32 * 32;
// Creates processor for the TensorImage.
int numRotation = sensorOrientation / 90;
this.processor =
    new ImageProcessor.Builder()
        .add(new ResizeOp(scaleHei, scaleWid, ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
        .add(new ResizeWithCropOrPadOp(imageSizeY, imageSizeX))
        .add(new Rot90Op(numRotation))
        .add(new GrayOp(imageSizeY, imageSizeX))
        .add(getPreprocessNormalizeOp())
        .build();
processor.process(inputImageBuffer);
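To see how this arithmetic arrives at the {1,224,320,3} shape mentioned in the problem description, here is a small standalone sketch of the same computation. The 640x480 bitmap size and the class name ScaleMath are assumptions made for illustration; they do not come from the original code.

```java
public class ScaleMath {
    // Target width the model expects (320 in the question).
    static final int TARGET_WIDTH = 320;

    /** Returns {scaledHeight, scaledWidth, heightRoundedDownToMultipleOf32}. */
    static int[] scaledDims(int bitmapWidth, int bitmapHeight) {
        float sc = bitmapWidth / (float) TARGET_WIDTH;
        int scaleWid = Math.round(bitmapWidth / sc);
        int scaleHei = Math.round(bitmapHeight / sc);
        // Integer division truncates, so the height snaps down to a multiple of 32.
        int imageSizeY = scaleHei / 32 * 32;
        return new int[]{scaleHei, scaleWid, imageSizeY};
    }

    public static void main(String[] args) {
        // Hypothetical 640x480 input bitmap.
        int[] d = scaledDims(640, 480);
        System.out.println(d[0] + " " + d[1] + " " + d[2]); // prints "240 320 224"
    }
}
```

With that assumed input, the scaled height of 240 is truncated to 224, which would make the resized input shape {1,224,320,3}.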
Run:
int imageTensorIndex = 0;
int[] imageShape = tflite.getInputTensor(imageTensorIndex).shape(); // {1, height, width, 3}
if (imageSizeY != imageShape[1] || imageShape[1] == 0) { // imageSizeY changed
    tflite.resizeInput(imageTensorIndex, new int[]{1, imageSizeY, imageSizeX, 3}); // resize inputTensor
    for (int i = 0, count = tflite.getOutputTensorCount(); i < count; i++) { // recreate outputTensor
        DataType pDataType = tflite.getOutputTensor(i).dataType();
        if (pDataType == DataType.INT32) pDataType = DataType.FLOAT32;
        outputArr[i] = TensorBuffer.createFixedSize(i == 0 ? new int[]{imageSizeY, imageSizeX} : new int[]{imageSizeX}, pDataType);
    }
}
Map<Integer, Object> outputMap = new HashMap<>();
for (int i = 0, cnt = outputArr.length; i < cnt; i++) { // rewind outputBuffer
    outputMap.put(i, outputArr[i].getBuffer().rewind());
    resizeOutputArr[i].getBuffer().rewind();
}
try {
    tflite.runForMultipleInputsOutputs(new Object[]{inputImageBuffer.getBuffer()}, outputMap);
} catch (IllegalArgumentException e) {
    e.printStackTrace();
    return null;
}
for (int i = 0, count = tflite.getOutputTensorCount(); i < count; i++) { // resize outputArr with new shape
    int[] shape = tflite.getOutputTensor(i).shape();
    outputArr[i].getBuffer().limit(computeFlatSize(shape) * DataType.FLOAT32.byteSize());
    resizeOutputArr[i].loadBuffer(outputArr[i].getBuffer(), shape);
    outputArr[i].getBuffer().clear();
    if (i == count - 1 && shape != null && shape[0] <= 0) return null;
}
return resizeOutputArr;

protected static int computeFlatSize(@NonNull int[] shape) {
    SupportPreconditions.checkNotNull(shape, "Shape cannot be null.");
    int prod = 1;
    for (int i = 0; i < shape.length; ++i) {
        prod *= shape[i];
    }
    return prod;
}
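The post-processing loop above relies on java.nio.ByteBuffer semantics: limit() restricts the readable region to the bytes of the actual output shape, and clear() restores the limit to the full capacity. The following is a minimal sketch with a plain ByteBuffer, assuming 4-byte FLOAT32 elements. The class name FlatSizeDemo and the helper limitToShape are hypothetical, and the sketch does not call TensorFlow Lite.

```java
import java.nio.ByteBuffer;

public class FlatSizeDemo {
    static final int FLOAT32_BYTES = 4; // assumed element size for FLOAT32

    /** Product of all dimensions; an empty shape (scalar) yields 1. */
    static int computeFlatSize(int[] shape) {
        int prod = 1;
        for (int dim : shape) {
            prod *= dim;
        }
        return prod;
    }

    /** Limits the buffer to the bytes of the given shape and returns the readable byte count. */
    static int limitToShape(ByteBuffer buffer, int[] shape) {
        buffer.rewind();
        buffer.limit(computeFlatSize(shape) * FLOAT32_BYTES);
        return buffer.remaining();
    }

    public static void main(String[] args) {
        // Buffer sized for the largest expected output (320 floats here).
        ByteBuffer buf = ByteBuffer.allocate(320 * FLOAT32_BYTES);
        System.out.println(limitToShape(buf, new int[]{5})); // prints 20
        buf.clear(); // limit goes back to the full capacity
        System.out.println(limitToShape(buf, new int[]{0})); // prints 0
        buf.clear();
        System.out.println(buf.limit()); // prints 1280
    }
}
```

A shape with a zero dimension, such as {0}, gives a flat size of 0, so the buffer exposes no readable bytes. This may correspond to the case where nothing is detected, which the shape[0] <= 0 check in the loop appears to guard against.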
When the bitmap contains no recognizable objects, I get this exception:
W: java.lang.IllegalArgumentException: Internal error: Tensor hasn't been allocated.
W: at org.tensorflow.lite.Tensor.buffer(Native Method)
W: at org.tensorflow.lite.Tensor.buffer(Tensor.java:493)
W: at org.tensorflow.lite.Tensor.copyTo(Tensor.java:264)
W: at org.tensorflow.lite.Tensor.copyTo(Tensor.java:254)
W: at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:170)
W: at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:343)
W: at com.fotric.irdetector.qf.tflite.Classifier.recognizeImage(Classifier.java:283)
This exception degrades recognition performance. Can anyone help? Thanks!
Solution