tensorflow - 为什么使用 Google AutoML 的浏览器中的结果与我的 android 应用程序中导出的 tflite 文件如此不同?
问题描述
我正在使用 Google AutoML 训练模型。然后我导出了一个 TFLite 模型以在 Android 应用程序上运行。
在应用程序和浏览器中使用相同的图像(Google AutoML 让您有机会在浏览器中测试您的模型),我在浏览器中得到了非常准确的答案,但在我的应用程序中,我得到了一个不太准确的答案.
关于为什么会发生这种情况的任何想法?
相关代码:
public void predict(String imagePath, final Promise promise) {
imgData.order(ByteOrder.nativeOrder());
labelProbArray = new byte[1][RESULTS_TO_SHOW];
try {
labelList = loadLabelList();
} catch (Exception ex) {
ex.printStackTrace();
}
Log.w("FIND_ ", imagePath);
Bitmap bitmap = BitmapFactory.decodeFile(imagePath);
convertBitmapToByteBuffer(bitmap);
try {
tflite = new Interpreter(loadModelFile());
} catch (Exception ex) {
Log.w("FIND_exception in loading tflite", "1");
}
tflite.run(imgData, labelProbArray);
promise.resolve(getResult());
}
private WritableNativeArray getResult() {
WritableNativeArray result = new WritableNativeArray();
//ArrayList<JSONObject> result = new ArrayList<JSONObject>();
try {
for (int i = 0; i < RESULTS_TO_SHOW; i++) {
WritableNativeMap map = new WritableNativeMap();
map.putString("label", Integer.toString(i));
float output = (float)(labelProbArray[0][i] & 0xFF) / 255f;
map.putString("prob", String.valueOf(output));
result.pushMap(map);
Log.w("FIND_label ", Integer.toString(i));
Log.w("FIND_prob ", String.valueOf(labelProbArray[0][i]));
}
} catch (Exception ex) {
ex.printStackTrace();
}
return result;
}
private List<String> loadLabelList() throws IOException {
Activity activity = getCurrentActivity();
List<String> labelList = new ArrayList<String>();
BufferedReader reader =
new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
String line;
while ((line = reader.readLine()) != null) {
labelList.add(line);
}
reader.close();
return labelList;
}
private void convertBitmapToByteBuffer(Bitmap bitmap) {
if (imgData == null) {
return;
}
imgData.rewind();
Log.w("FIND_bitmap width ", String.valueOf(bitmap.getWidth()));
Log.w("FIND_bitmap height ", String.valueOf(bitmap.getHeight()));
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
// Convert the image to floating point.
int pixel = 0;
//long startTime = SystemClock.uptimeMillis();
for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
final int val = intValues[pixel++];
imgData.put((byte) ( ((val >> 16) & 0xFF)));
imgData.put((byte) ( ((val >> 8) & 0xFF)));
imgData.put((byte) ( (val & 0xFF)));
Log.w("FIND_i", String.valueOf(i));
Log.w("FIND_j", String.valueOf(j));
}
}
}
private MappedByteBuffer loadModelFile() throws IOException {
Activity activity = getCurrentActivity();
AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
我不知道是否正确的一件事是我如何将输出数据调整为概率。在getResult()
方法中,我这样写:
float output = (float)(labelProbArray[0][i] & 0xFF) / 255f;
我除以 255 是为了得到一个介于 0 和 1 之间的数字,但我不知道这是否正确。
如果有帮助,.tflite 模型是量化的 uint8 模型。
解决方案
推荐阅读
- ios - 未找到 -lboost_context 的库
- python - 无效的块标记“其他”。您是否忘记注册或加载此标签?
- wordpress - Wordpress 上的规范页面
- automation - Ansible cisco ios,改变接口上的vlan
- javascript - Javascript - 寻找一种转换父子关系的方法
- elasticsearch - 由于控制进程收到致命信号(code=killed, signal=9/KILL),无法启动 elasticsearch
- javascript - 未捕获的类型错误:无法读取 null 的属性“getAttribute”(传单控制层)
- python-3.x - Dask Distributed 2021.4.1 序列化失败
- arrays - 使用 JSONArray 作为输入的 Spark Streaming [Pyspark]
- android - 基于网站在Android App上启用WebRTC