c++ - 尝试将 TensorFlow C++ API 与 Cast 操作(Movenet 模型)一起使用会导致推理时出错
问题描述
您好我正在尝试将新的 MoveNet 模型与 tensorlow C++ API 一起使用。
这是我尝试加载 SavedModel 和图像的代码的一部分。
加载模型
// Load
SavedModelBundleLite model_bundle;
SessionOptions session_options = SessionOptions();
session_options.config.mutable_gpu_options()->set_allow_growth(true);
RunOptions run_options = RunOptions();
std::string export_dir = "/workspace/modelZoo";
Status status = LoadSavedModel(session_options, run_options, export_dir, {"serve"}, &model_bundle);
加载图像
//Load image
int32_t input_width = 256;
int32_t input_height = 256;
using namespace ::tensorflow::ops;
std::string filename = "image.jpg";
Scope root = Scope::NewRootScope();
auto output = tensorflow::ops::ReadFile(root.WithOpName("file_reader"), filename);
const int wanted_channels = 3;
tensorflow::Output image_reader = tensorflow::ops::DecodeJpeg(root.WithOpName("file_decoder"), output, tensorflow::ops::DecodeJpeg::Channels(wanted_channels));
auto image_int32 = tensorflow::ops::Cast(root.WithOpName("int32_caster"), image_reader, tensorflow::DT_INT32);
auto dims_expander = tensorflow::ops::ExpandDims(root.WithOpName("expand_dims"), image_int32, 0);
auto resized = tensorflow::ops::ResizeBilinear(root, dims_expander, tensorflow::ops::Const(root.WithOpName("size"), {input_height, input_width}));
std::vector<Tensor> out_tensors;
ClientSession session(root);
auto run_status = session.Run({resized}, &out_tensors);
和推理
const std::string input_name_1 = model_def.inputs().at("input").name();
std::vector<std::pair<string, Tensor>> inputs_data = {{input_name_1, out_tensors[0]}};
std::string outputLayer = model_def.outputs().at("output_0").name();
std::vector<Tensor> outputs;
Status runStatus = model_bundle.GetSession()->Run(inputs_data, {outputLayer}, {}, &outputs);
当代码运行推理时,我收到此错误。
期望 arg[0] 为 int32 但提供了浮点数
根据来自 tensorflow hub 的教程(https://tfhub.dev/google/movenet/multipose/lightning/1),模型需要 int32 格式,但我不知道我的代码有什么问题,指的是铸造部分输入张量
解决方案
推荐阅读
- reactjs - “number[]”类型的参数不能分配给“SetStateAction”类型的参数
' - python - 忽略要在 pytest 中测试的函数
- ios - xcode12 问题:ld :为 iOS 模拟器构建,但链接到为 iOS 构建的目标文件,架构 arm64 的文件 'xxx.framework/xxx'
- android - 在 Android 中设置相机焦距
- fortran - 如何在 OpenVMS Fortran 中获取命令行参数?
- python - 正则表达式程序从包含 2 个数字且它们不相邻的句子中打印单词
- javascript - REACT : 当子组件出错时隐藏父组件
- azure-active-directory - 在 AAD 组中添加外部用户以访问 Power BI 工作区
- twitter-bootstrap - 容器没有相同的边距
- mysql - MySQL 5.7 如何将主查询数据传入子子查询