tensorflow - 使用 C++ API 使用 parsing_serving_input_receiver_fn 推断 Tensorflow lite 模型的示例
问题描述
我已按照 Tensorflow2 文档将训练有素的 tf.estimator 模型转换为 tflite 模型;为了转换我的模型,首先我必须使用 input_receiver_fn 将模型保存为 saved_model 格式,然后使用 SELECT_OPS 标志进行转换:
classifier = tf.estimator.LinearClassifier(n_classes=2, model_dir = classifier_dir, feature_columns=features)
classifier.train(input_fn = lambda: trian_fn(features = train_datas, labels = trian_labels))
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(tf.feature_column.make_parse_example_spec(features))
classifier.export_saved_model(classifier_dir+"\saved_model", serving_input_fn)
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir = saved_model_dir , signature_keys=['serving_default'])
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
我想在没有 python 支持的 ARM 设备上运行我的 tflite 模型,所以我用 Bazel 构建了 C++ 解释器共享库,如文档中所述:
我的模型有 3 个输入特征,但是当我尝试使用以下指南进行推理时,我遇到了分段错误。我使用以下代码来提取我的模型详细信息:
interpreter = tf.lite.Interpreter(model_path="./model.tflite")
interpreter.allocate_tensors()
print("all ok")
# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))
# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))
我得到以下输出:
all ok
1 input(s):
[1] <class 'numpy.bytes_'>
2 output(s):
[1 2] <class 'numpy.bytes_'>
[1 2] <class 'numpy.float32'>
tflite::PrintInterpreterState(interpreter.get()) 输出的前几行是:
INFO: Created TensorFlow Lite delegate for select TF ops.
INFO: TfLiteFlexDelegate delegate: 1 nodes delegated out of 25 nodes with 1 partitions.
Interpreter has 54 tensors and 26 nodes
Inputs: 0
Outputs: 38 34
Tensor 0 input_example_tensor kTfLiteString kTfLiteDynamic 0 bytes ( 0.0 MB) 1
输出说明输入形状与原始模型不同,输入类型也是 <class 'numpy.bytes_'> 但 Tensorflow 2 模型输入是 [numpy.float32, numpy.float32, numpy.float32]。我在 TF2 模型中用于预测的输入字典类似于: {'feature0' : data0, 'feature1' : data1, 'feature2' : data2}
这是 Tensorflow 模型的 Google Colab链接我以前没有推理 TensorFlow Lite 模型的经验,所以我首先搜索并发现了这些相关问题,这些问题帮助我编写了以下 C++ 代码:
用于推理的 TensorFlow Lite C++ API 示例
我试图用零向量填充输入缓冲区,但没有成功。这是我的 C++ 代码,用于加载 tflite 模型并为其提供输入以进行预测。有人可以指点我正确的方向吗,因为我找不到任何示例或相关文档来使用serving_input_fn向转换后的tf.estimator提供输入。
#include <cstdio>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/optional_debug_tools.h"
int main()
{
// Load model
std::unique_ptr<tflite::FlatBufferModel> model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
// Build the interpreter with the InterpreterBuilder.
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder builder(*model, resolver);
std::unique_ptr<tflite::Interpreter> interpreter;
builder(&interpreter);
tflite::PrintInterpreterState(interpreter.get());
// Allocate tensor buffers.
interpreter->AllocateTensors();
printf("=== Pre-invoke Interpreter State ===\n");
tflite::PrintInterpreterState(interpreter.get());
// Fill input buffers
std::vector<float> tensor(3, 0); //Vector of zeros
int input = interpreter->inputs()[0];
float* input_data_ptr = interpreter->typed_input_tensor<float>(input);
for(int i = 0; i < 3; ++i)
{
*(input_data_ptr) = (float)tensor[i];
input_data_ptr++;
}
// Run inference
interpreter->Invoke();
printf("\n\n=== Post-invoke Interpreter State ===\n");
return 0;
}
编辑1:
我还在 Tensorflow 的 GitHub 上问了这个问题,并得到了一条评论,提到我必须以“示例原型”的形式提供我的输入,现在问题被简化为什么是“示例原型”以及如何将输入提供给来自示例原型的 tflite 模型?
解决方案
推荐阅读
- mysql - 印刷店数据库设计
- machine-learning - 我如何使用遗传算法表示染色体?
- tensorflow - Tensorflow 以兼容的方式保存和恢复模型(渴望和图形模式)
- ubuntu - 在没有互联网的情况下在 ubuntu 上编译 gcc7 和 gcc
- python - django 中的 Python 代码打印当前版权年份
- docker - 如何从 docker 容器访问主机 docker 实例
- php - 使用 php 快速更新多个具有不同值的 mysql 行(大约 100 毫秒)
- generics - FSharpPlus divRem - 它是如何工作的?
- php - :checked 在 css 中不工作
- google-bigquery - 在 Google BigQuery 中,如何获取时间分区表中分区的存储大小?