c - 这个 TFLite 输出对应的 C 对象的形状是什么?
问题描述
我有一个 YOLOv5 训练模型转换为 .tflite 格式并使用了本指南。
我使用此代码在 python 中打印输入和输出形状:
interpreter = tf.lite.Interpreter(
# model_path="models/exported_resnet640.tflite") # centernet_512x512 works correctly
model_path="models/yolov5_working.tflite") # centernet_512x512 works correctly
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("======================================================")
print(input_details)
print("======================================================")
# print(output_details)
for detail in output_details:
print(detail)
print(" ")
输出如下所示:
======================================================
[{'name': 'input_1', 'index': 0, 'shape': array([ 1, 480, 480, 3], dtype=int32), 'shape_signature': array([ 1, 480, 480, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
======================================================
{'name': 'Identity', 'index': 422, 'shape': array([ 1, 14175, 9], dtype=int32), 'shape_signature': array([ 1, 14175, 9], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}
在给出一些输入后调用解释器后,我得到一个如下所示的输出:
Output: [[[0.01191081 0.01366316 0.02800988 ... 0.1661754 0.31489396 0.4217688 ]
[0.02396268 0.01650745 0.0442626 ... 0.24655405 0.35853994 0.2839473 ]
[0.04218047 0.01613732 0.0548977 ... 0.13136038 0.25760946 0.5338376 ]
...
[0.82626414 0.9669814 0.4534862 ... 0.18754318 0.11680853 0.18492043]
[0.8983849 0.9680944 0.64181983 ... 0.19781056 0.16431764 0.16926363]
[0.9657682 0.9869368 0.5452545 ... 0.13321301 0.12015155 0.15937251]]]
使用 Tensorflow Lite c_api.h,我试图在 C 中获得相同的输出,但我无法理解如何创建获取数据的对象。
我尝试使用float***
with size1 * 14715 * 9 * sizeof(float)
并获得如下输出:
int number_of_detections = 14175;
struct filedata o_boxes;
float ***box_coords = (float ***)malloc(sizeof(float **) * 1);
box_coords[0] = (float **)malloc(sizeof(float *) * (int)number_of_detections);
for (int i = 0; i < (int)number_of_detections; i++)
{
box_coords[0][i] = (float *)calloc(sizeof(float), 9); // box has 9 coordinates
}
o_boxes.data = box_coords;
o_boxes.size = 1 * (int)number_of_detections * 9 + 1;
const TfLiteTensor *output_tensor_boxes =
TfLiteInterpreterGetOutputTensor(interpreter, 0);
TfLiteTensorCopyToBuffer(output_tensor_boxes, o_boxes.data,
o_boxes.size * sizeof(float));
box_coords = (float ***)&o_boxes.data;
for (int i = 0; i < o_boxes.size; i++)
{
for (int j = 0; j < 9; j++)
{
printf("%f ", box_coords[0][i][j]);
fflush(stdout);
}
printf("\n");
}
哪里struct filedata
是一个简单的结构:
struct filedata
{
void *data;
size_t size;
};
结果是一些垃圾大花车:
39688651931648.000000 0.000000 39805756899328.000000 0.000000 39807166185472.000000 0.000000 39807367512064.000000 0.000000 39807568838656.000000
在第一次迭代之后,我得到一个分段错误。
我应该如何创建/分配我的浮点数组来获取我的数据?
解决方案
显然,结果都在一行中,所以我修改了代码,如下所示:
int number_of_detections = 14175;
struct filedata o_boxes;
float **box_coords = malloc(sizeof(float *) * number_of_detections);
for (int i = 0; i < number_of_detections; i++)
{
box_coords[i] = calloc((9 + 1), sizeof(float)); // box has 9 coordinates, added 1 to be sure
}
o_boxes.data = (void *)box_coords;
o_boxes.size = (number_of_detections * 9);
const TfLiteTensor *output_tensor_boxes =
TfLiteInterpreterGetOutputTensor(interpreter, 0);
TfLiteTensorCopyToBuffer(output_tensor_boxes, o_boxes.data,
o_boxes.size * sizeof(float));
box_coords = (float **)&o_boxes.data; // not entirely sure why we need & there
for (int i = 0; i < number_of_detections; i++)
{
for (int j = 0; j < 9; j++)
{
printf("%f ", box_coords[0][j + i * 9]); // we know we have 9 coordinates, and every line is 9 floats away
fflush(stdout);
}
printf("\n");
}
return 0;
}
现在就像一个魅力!
推荐阅读
- angular - 无法覆盖 MockStore
- parse-error - 解析错误:语法错误,意外的 '(',第 10 行 C:\xampp\htdocs\Expert\class\users.php 中的预期变量 (T_VARIABLE)
- headless - 现代gl无头你好世界给出了意想不到的结果
- angular - 使用 *ngFor 时如何避免 :enter/:leave 值更改时的动画?
- c# - Hololens - Unity c#:将点集从相机计划投影到统一坐标系?
- angularjs - 打字稿角文件阅读器
- symbolic-math - 如何在 Sage 中的函数中使用微分
- python - AttributeError:“HyperbandSearchCV”对象没有属性“_get_param_iterator”
- angular - 在 Angular 中将数据从搜索栏调用到 chart.js 图表
- python - 更改 QInputDialog 的 CancelButtonText