python - 我可以获取 tensorflow lite 模型的指标吗?
问题描述
我一直在研究一个带有自定义指标的复杂 Keras 模型,最近我将它转换为 tensorflow lite。模型不完全相同,输出也不同,但是很难评估,因为输出是大小为 128 的张量。有什么方法可以在这个模型上运行我的自定义指标?我一直在使用 Tf 1.14。下面是一些相关的代码。
# compiler and train the model
model.save('model.h5')
# save the model in TFLite
converter = tf.lite.TFLiteConverter.from_keras_model_file('model.h5', custom_objects={'custom_metric': custom_metric})
tflite_model = converter.convert()
open('model.tflite', 'wb').write(tflite_model)
# run the model
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()
input_dets = interpreter.get_input_details()
output_dets = interpreter.get_output_details()
input_shape = input_dets[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_dets[0]['index'], input_data)
interpreter.invoke()
解决方案
这些模型应该是不同的,因为转换器会进行图形转换(例如熔断器激活和折叠批量规范),并且生成的图形仅针对推理场景。
运行指标:解释器提供了一个 API 来获取输出值(作为一个数组):
output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
然后将指标应用于输出。
推荐阅读
- apache-nifi - 使用嵌入式 Zookeeper 将节点添加到 NIFI 集群
- python-2.7 - 如何更改我的颜色条就像 matplotlib 中的滑块一样工作
- php - 如何在 laravel 5.6 中进行自定义登录?
- hadoop - 在 cloudera 5.13.0 服务未启动
- variables - Prestashop 1.7.4.x - 使用 order_conf 作为新电子邮件模板时缺少变量值
- javascript - vis js时间线边界上的项目被剪切/部分可见
- android - 在Firebase中划分部分时如何显示当前登录用户数据?
- android - 活动暂停时停止调用 onLocationChanged()
- amazon-web-services - 由于 Windows 防火墙,RDP 无法访问 EC2 实例
- sql-server - 如何在 SQL Server 中使用 Pivot?