tensorflow - 将经过训练的 HDF5 模型加载到 Rust 中以进行预测
问题描述
我使用 MNIST 数据集训练了一个模型来识别数字。该模型已使用 TensorFlow 和 Keras 在 Python 中进行了训练,并将输出保存到我命名为“sample_mnist.h5”的 HDF5 文件中。
我想将经过训练的模型从 HDF5 文件加载到 Rust 中以进行预测。
在 Python 中,我可以从 HDF5 生成模型并使用代码进行预测:
model = keras.models.load_model("./sample_mnist.h5")
model.precict(test_input) # assumes test_input is the correct input type for the model
这个 Python 片段的 Rust 等价物是什么?
解决方案
首先,您需要将模型保存为.pb格式,而不是.hdf5,以便将其移植到 Rust,因为这种格式保存了在 Python 之外重建模型所需的关于模型执行图的所有内容。TensorFlow Rust repo 上有一个来自用户justnoxx的开放拉取请求,展示了如何为简单模型执行此操作。要点是在 Python 中给出了一些模型......
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
classifier = Sequential()
classifier.add(Dense(5, activation='relu', name="test_in", input_dim=5)) # Named input
classifier.add(Dense(5, activation='relu'))
classifier.add(Dense(1, activation='sigmoid', name="test_out")) # Named output
classifier.compile(optimizer ='adam', loss='binary_crossentropy', metrics=['accuracy'])
classifier.fit([[0.1, 0.2, 0.3, 0.4, 0.5]], [[1]], batch_size=1, epochs=1);
classifier.save('examples/keras_single_input_saved_model', save_format='tf')
以及我们命名的输入“test_in”和输出“test_out”以及它们的预期大小,我们可以在 Rust 中应用保存的模型......
use tensorflow::{Graph, SavedModelBundle, SessionOptions, SessionRunArgs, Tensor};
fn main() {
// In this file test_in_input is being used while in the python script,
// that generates the saved model from Keras model it has a name "test_in".
// For multiple inputs _input is not being appended to signature input parameter name.
let signature_input_parameter_name = "test_in_input";
let signature_output_parameter_name = "test_out";
// Initialize save_dir, input tensor, and an empty graph
let save_dir =
"examples/keras_single_input_saved_model";
let tensor: Tensor<f32> = Tensor::new(&[1, 5])
.with_values(&[0.1, 0.2, 0.3, 0.4, 0.5])
.expect("Can't create tensor");
let mut graph = Graph::new();
// Load saved model bundle (session state + meta_graph data)
let bundle =
SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, save_dir)
.expect("Can't load saved model");
// Get the session from the loaded model bundle
let session = &bundle.session;
// Get signature metadata from the model bundle
let signature = bundle
.meta_graph_def()
.get_signature("serving_default")
.unwrap();
// Get input/output info
let input_info = signature.get_input(signature_input_parameter_name).unwrap();
let output_info = signature
.get_output(signature_output_parameter_name)
.unwrap();
// Get input/output ops from graph
let input_op = graph
.operation_by_name_required(&input_info.name().name)
.unwrap();
let output_op = graph
.operation_by_name_required(&output_info.name().name)
.unwrap();
// Manages inputs and outputs for the execution of the graph
let mut args = SessionRunArgs::new();
args.add_feed(&input_op, 0, &tensor); // Add any inputs
let out = args.request_fetch(&output_op, 0); // Request outputs
// Run model
session.run(&mut args) // Pass to session to run
.expect("Error occurred during calculations");
// Fetch outputs after graph execution
let out_res: f32 = args.fetch(out).unwrap()[0];
println!("Results: {:?}", out_res);
}
推荐阅读
- java - Maven:将特定类型的工件从多个工件部署到存储库管理器(例如 Nexus)
- c# - 试图学习如何从 c# 中的表单获取用户输入的数据以显示到标签上,Visual Studio 没有显示错误,但标签不会更新
- python - ValueError:操作数无法与形状一起广播 (2,6) (6,2)
- python - Python C API:如何刷新标准输出和标准错误?
- python - 使用外部程序打开文件,但保持当前窗口处于活动状态
- arrays - 在 API 调用后更改元素的状态以显示组件数组
- swift - 如何使用 Swift 代码在 Mac OS 的屏幕上打开键盘
- javascript - JQuery/Bootpag 分页没有将下一页返回到结果顶部
- c# - 非空对象上的 JToken.FromObject
- mongoose - 如何使用 Mongoose 虚拟的价值?