首页 > 解决方案 > Tf Lite 模型图像分类打印标签

问题描述

我正在研究 Image Claasification TF Lite 模型,以使用此链接检测人脸的蒙版或不蒙版。我按照链接在顶点 AI 中训练了图像多类分类并下载了 TF lite 模型。模型的标签是“mask”和“no_mask”。为了测试模型,我编写了以下代码:

interpret= tf.lite.Interpreter(model_path="<FILE_PATH>")
input= interpret.get_input_details()
output= interpret.get_output_details()

interpret.allocate_tensors()

pprint(input)
pprint(output)

data= cv2.imread("file.jpeg")
new_image= cv2.resize(data,(224,224))

interpret.resize_tensor_input(input[0]["index"],[1,224,224,3])
interpret.allocate_tensors()
interpret.set_tensor(input[0]["index"],[new_image])
interpret.invoke()
result= interpret.get_tensor(output[0]['index'])

print (" Prediction is - {}".format(result))

将此代码用于我的一张图片给我的结果是:

[[30 246]]

现在我也想在结果中打印标签。例如:

面具:30

无掩码:46

有什么办法可以实现吗?

请帮忙,因为我是 TF Lite 的新手

标签: pythontensorflowtensorflow-liteimage-classification

解决方案


我自己解决了。从 Vertex AI 下载的 .tflite 模型包含名为“dict.txt”的标签文件,其中包含所有标签。在此处查看 GCP 文档。要获取此标签文件,我们首先需要解压缩 .tflite 文件,该文件将为我们提供 dict.txt。有关更多信息,请查看tflite 文档以及如何从模型中读取关联文件

之后,我从github 链接 label.py中执行了以下代码:

import argparse
import time

import numpy as np
from PIL import Image
import tensorflow as tf
interpret= tf.lite.Interpreter(model_path="<FILE_PATH>")
input= interpret.get_input_details()
output= interpret.get_output_details()

interpret.allocate_tensors()

pprint(input)
pprint(output)

data= cv2.imread("file.jpeg")
new_image= cv2.resize(data,(224,224))

interpret.resize_tensor_input(input[0]["index"],[1,224,224,3])
interpret.allocate_tensors()
interpret.set_tensor(input[0]["index"],[new_image])
interpret.invoke()
floating_model= input[0]['dtype'] == np.float32
op_data= interpret.get_tensor(output[0]['index'])
result= np.squeeze(op_data)

top_k=result.agrsort()[-5:][::1]
labels=load_labels("dict.txt")
for i in top_k:
   if floating_model:
      print('{:08.6f}: {}'.format(float(result[i]), labels[i]))
    else:
      print('{:08.6f}: {}'.format(float(result[i] / 255.0), labels[i]))


推荐阅读