tensorflow-lite - 如何使用 c++ api 在 tflite 中获取权重?
问题描述
我在设备上使用 .tflite 模型。最后一层是 ConditionalRandomField 层,我需要该层的权重来进行预测。如何使用 c++ api 获得权重?
Netron 或 flatc 不能满足我的需求。设备太重。
似乎 TfLiteNode 将权重存储在 void* user_data 或 void* builtin_data 中。我如何阅读它们?
更新:
结论:.tflite 不存储 CRF 权重,而 .h5 剂量。(也许是因为它们不影响输出。)
我所做的:
// obtain from model.
Interpreter *interpreter;
// get the last index of nodes.
// I'm not sure if the index sequence of nodes is the direction which tensors or layers flows.
const TfLiteNode *node = &((interpreter->node_and_registration(interpreter->nodes_size()-1))->first);
// then follow the answer of @yyoon
解决方案
在 TFLite 节点中,权重应该存储在inputs
数组中,其中包含对应的索引TfLiteTensor*
。
所以,如果你已经获得了TfLiteNode*
最后一层的,你可以做这样的事情来读取权重值。
TfLiteContext* context; // You would usually have access to this already.
TfLiteNode* node; // <obtain this from the graph>;
for (int i = 0; i < node->inputs->size; ++i) {
TfLiteTensor* input_tensor = GetInput(context, node, i);
// Determine if this is a weight tensor.
// Usually the weights will be memory-mapped read-only tensor
// directly baked in the TFLite model (flatbuffer).
if (input_tensor->allocation_type == kTfLiteMmapRo) {
// Read the values from input_tensor, based on its type.
// For example, if you have float weights,
const float* weights = GetTensorData<float>(input_tensor);
// <read the weight values...>
}
}
推荐阅读
- python - 切换到 64 位时出现 Python Winerror 193
- javascript - SyntaxError:输入 Discord Bot 意外结束
- java - 使用 Jsoup 非递归提取文本
- javascript - 如何在firebase登录或登录中设置或添加自定义提供程序名称,例如linkedin、instagram?
- c# - ScrapySharp 表单提交导致 System.AggregateException
- c# - 为什么将数据导出到 Excel 时日期格式会发生变化?
- python-3.x - Python中的闭包 - 执行内部函数?
- c# - 左上原点的 OpenTK 正交投影
- amp-html - 如何在 amp 电子邮件中收集和存储动态数据?
- shell - 输出到文件时shell换行的问题