python - 如何在 Raspberry Pi 中加载 tflite 权重
问题描述
我正在使用神经网络,我需要使用 Raspberry Pi v2。
当我想安装 tensorflow 2.X 时它失败了,我只能安装 tensorflow 1.14。出于这个原因,我找到了tflite
一个理论上可以帮助我的库,它有一个精简版的 tf.
这里的图像显示我无法安装它。
首先,我将我的 keras 模型(model.h5)转换为.tflite
模型。
# Convert the model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Save the model.
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
既然来了,一切OK。问题是当我想使用这个模型时。有了张量流,我知道该怎么做,
from tensorflow import keras
def importModel(myPath):
file = open(myPath+'model/model.json', 'r')
model_json = file.read(); file.close()
model = keras.models.model_from_json(model_json)
model.load_weights(myPath+'model/model.h5')
return model, scaler
但是我真的不明白该怎么做tflite
,有人可以帮助我吗?
解决方案
你可以在官方文档中找到这个
import numpy as np
import tensorflow as tf
# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Test the model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
如果您在树莓派上安装 TensorFlow 2.x 时遇到问题,可能是因为您没有使用最新版本的 Python3
推荐阅读
- c++ - 将节点添加到链表(分段错误)
- reactjs - 通过 Axios 发送 FormData 到 Express 应用程序
- c# - C#等待值被设置
- docker - 在 Google sdk docker 映像中安装 sshuttle
- mysql - BASH MySQL 输出到动态变量
- reactjs - 问题“调用 Redux 操作时状态值更改为未定义”
- java - java - 如何在java中将st上标为数字1?
- c - C中的双向链表,忽略最终条目,或者想要一个没有的条目
- javascript - AJAX 调用 - 动态 JSON 数据库数组 - 隐藏 HTML 元素 - 具有 CSS 可见性属性
- javascript - Mysql结果,数组到json但添加一个元素