首页 > 解决方案 > 使用 Tensorflow Lite [Python/Flutter 集成] 获取正确的对象检测标签

问题描述

我正在尝试在 Flutter 上使用 MobileNetV2 模型实现对象检测。由于 Flutter 应用程序在线提供的大多数示例或实现都没有使用 MobileNetV2,所以我走了很长的路才到达那个阶段。

我实现这一目标的方式如下:

1) 创建了一个 python 脚本,我在其中使用 Keras(后端 Tensorflow)的 MobileNetV2 模型(在 ImageNet 上针对 1000 个类进行了预训练),并用图像对其进行了测试,以查看它在正确检测到对象后是否返回了正确的标签。[下面提供的Python脚本供参考]

2)将相同的MobileNetV2 keras模型(MobileNetV2.h5)转换为Tensorflow Lite模型(MobileNetV2.tflite)

3) 按照现有的创建 Flutter 应用程序的示例来使用 Tensorflow Lite ( https://itnext.io/working-with-tensorflow-lite-in-flutter-f00d733a09c3 )。将示例中显示的 TFLite 模型替换为 MobileNetV2.tflite 模型,并使用https://gist.github.com/aaronpolhamus/964a4411c0906315deb9f4a3723aac57中的 ImageNet 类/标签作为 labels.txt。[此处提供 Flutter 示例的 GitHub 项目:https://github.com/umair13adil/tensorflow_lite_flutter]

当我现在运行 Flutter 应用程序时,它运行时没有任何错误,但是在分类/预测标签期间,输出不正确。例如:它将橙子(对象 id:n07747607)分类为 poncho(对象 id:n03980874),将石榴(对象 id:n07768694)分类为 banded_gecko(对象 id:n01675722)。

但是,如果我使用相同的图片并使用我的 python 脚本对其进行测试,它会返回正确的标签。所以,我想知道问题是否真的出在 Flutter 应用程序中使用的 label.txt(保存标签),其中标签的顺序与模型的推断不匹配。

谁能提到我如何解决问题以对正确的对象进行分类?如何获取 MobileNetV2 (keras) 使用的 ImageNet 标签,以便我可以在 Flutter 应用程序中使用它?

我的使用 MobileNetv2 检测对象的 Flutter App 可以从这里下载:https ://github.com/somdipdey/Tensorflow_Lite_Object_Detection_Flutter_App

我的 python 脚本将 MobileNetV2 模型 (keras) 转换为 TFLite,同时在图像上对其进行分类测试,如下所示:

import tensorflow as tf
from tensorflow import keras

from keras.preprocessing import image
from keras.applications.mobilenet_v2 import preprocess_input, decode_predictions
import numpy as np

import PIL
from PIL import Image
import requests
from io import BytesIO


# load the model
model = tf.keras.applications.MobileNetV2(weights='imagenet', include_top=True)

#model = tf.keras.models.load_model('MobileNetV2.h5')

# To save model
model.save('MobileNetV2.h5')

# chose the URL image that you want
URL = "https://images.unsplash.com/photo-1557800636-894a64c1696f?ixlib=rb-1.2.1&w=1000&q=80"
# get the image
response = requests.get(URL)
img = Image.open(BytesIO(response.content))
# resize the image according to each model (see documentation of each model)
img = img.resize((224, 224))

##############################################
# if you want to read the image from your PC
#############################################
# img_path = 'myimage.jpg'
# img = image.load_img(img_path, target_size=(299, 299))
#############################################



# convert to numpy array
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

features = model.predict(x)

# return the top 10 detected objects
num_top = 10
labels = decode_predictions(features, top=num_top)
print(labels)

#load keras model
new_model= tf.keras.models.load_model(filepath="MobileNetV2.h5")
# Create a converter # I could also directly use keras model instead of loading it again
converter = tf.lite.TFLiteConverter.from_keras_model(new_model)
# Convert the model
tflite_model = converter.convert()
# Create the tflite model file
tflite_model_name = "MobileNetV2.tflite"
open(tflite_model_name, "wb").write(tflite_model)

标签: pythontensorflowflutterkerasobject-detection

解决方案


让我首先以JSONtxt两种格式共享 ImageNet 标签。鉴于 MobileNetV2 是在 ImageNet 上训练的,它应该根据这些标签返回结果。

我最初的想法是管道的第二步一定有错误。我假设您正在尝试将经过训练的基于 Keras 的权重转换为 Tensorflow Lite 权重(与纯 Tensorflow 的格式是否相同?)。一个不错的选择是尝试以 Tensorflow Lite 的格式查找已保存的权重(但我猜它们可能不可用,这就是您进行转换的原因)。我在将 TF 权重转换为 Keras 时遇到了类似的问题,因此您必须在进入第 3 步(创建 Flutter 应用程序以使用 Tensorflow Lite)之前确定转换是否成功完成。实现这一点的一个好方法是打印分类器的所有可用类,并将它们与上面给出的原始 ImageNet 标签进行比较。


推荐阅读