python - Tf Lite 模型图像分类打印标签
问题描述
我正在研究 Image Claasification TF Lite 模型,以使用此链接检测人脸的蒙版或不蒙版。我按照链接在顶点 AI 中训练了图像多类分类并下载了 TF lite 模型。模型的标签是“mask”和“no_mask”。为了测试模型,我编写了以下代码:
interpret= tf.lite.Interpreter(model_path="<FILE_PATH>")
input= interpret.get_input_details()
output= interpret.get_output_details()
interpret.allocate_tensors()
pprint(input)
pprint(output)
data= cv2.imread("file.jpeg")
new_image= cv2.resize(data,(224,224))
interpret.resize_tensor_input(input[0]["index"],[1,224,224,3])
interpret.allocate_tensors()
interpret.set_tensor(input[0]["index"],[new_image])
interpret.invoke()
result= interpret.get_tensor(output[0]['index'])
print (" Prediction is - {}".format(result))
将此代码用于我的一张图片给我的结果是:
[[30 246]]
现在我也想在结果中打印标签。例如:
面具:30
无掩码:46
有什么办法可以实现吗?
请帮忙,因为我是 TF Lite 的新手
解决方案
我自己解决了。从 Vertex AI 下载的 .tflite 模型包含名为“dict.txt”的标签文件,其中包含所有标签。在此处查看 GCP 文档。要获取此标签文件,我们首先需要解压缩 .tflite 文件,该文件将为我们提供 dict.txt。有关更多信息,请查看tflite 文档以及如何从模型中读取关联文件。
之后,我从github 链接 label.py中执行了以下代码:
import argparse
import time
import numpy as np
from PIL import Image
import tensorflow as tf
interpret= tf.lite.Interpreter(model_path="<FILE_PATH>")
input= interpret.get_input_details()
output= interpret.get_output_details()
interpret.allocate_tensors()
pprint(input)
pprint(output)
data= cv2.imread("file.jpeg")
new_image= cv2.resize(data,(224,224))
interpret.resize_tensor_input(input[0]["index"],[1,224,224,3])
interpret.allocate_tensors()
interpret.set_tensor(input[0]["index"],[new_image])
interpret.invoke()
floating_model= input[0]['dtype'] == np.float32
op_data= interpret.get_tensor(output[0]['index'])
result= np.squeeze(op_data)
top_k=result.agrsort()[-5:][::1]
labels=load_labels("dict.txt")
for i in top_k:
if floating_model:
print('{:08.6f}: {}'.format(float(result[i]), labels[i]))
else:
print('{:08.6f}: {}'.format(float(result[i] / 255.0), labels[i]))
推荐阅读
- javascript - 如何在 markdownit 插件中链接图像
- bash - 无法将差异输出拆分为数组
- java - 带有 Docker 的 Spring 应用程序在启动时因类路径错误而失败
- javascript - 有没有办法在角度应用程序中获取文档的高度?
- c# - 在 C# Form 应用程序中的用户控件之间传递数据
- java - 如何在 Fluent BPMN API 中创建关联流?
- python - 如何使 Python LightGBM 代码接受列表
- python - Pandas groupby 中的计数频率
- kubernetes - 为 Kubernetes 作业提供文件
- python - 将当前元素与列表中的任何元素进行比较