c++ - tensorflow lite中的跨步切片错误
问题描述
我正面临 tensorflow-lite 的问题。我收到此错误:
不支持类型 INT32 (2)。节点 STIDED_SLICE(编号 2)调用失败,状态为 1
我所做的是:
我用 MNIST 数据训练了一个模型。
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10)
])
我使用integer-only quantization转换了模型。但是,当我调用模型时,它会引发该错误。
我正在查看 striced_slice.cc,我发现了这个:
switch (output->type) {
case kTfLiteFloat32:
reference_ops::StridedSlice(op_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
case kTfLiteUInt8:
reference_ops::StridedSlice(
op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output));
break;
case kTfLiteInt8:
reference_ops::StridedSlice(op_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
所以不支持int32。我不太确定如何处理这种问题。有没有办法改变这个节点上的行为?我应该在量化步骤中做一些不同的事情吗?
我所做的是:
def representative_data_gen():
for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
yield [input_value]
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
open("model_int8.tflite", "wb").write(tflite_model)
PD:我正在使用 tensorflow-lite 以在 stm32 中使用。
提前致谢。
解决方案
当您执行全整数量化时,您的输入和输出应为 1 字节长(在您的情况下为 int8)。提供 int8 值作为输入,您将能够调用您的模型。
推荐阅读
- node.js - 如果目标类被扩展,TypeScript 合并声明组合不起作用
- google-cloud-platform - 什么是 BigQuery DML 配额限制
- python - TypeError: __init__() 接受 1 个位置参数,但给出了 2 个(Python multiprocessing with Pytesseract)
- javascript - 查找输入字段的单元格 ID?
- android - Flipper 网络插件不显示网络请求
- svg - 三个 js 线框球体 svg
- r - 刻面和 scale_color_manual
- linux - 在带有 OpenGL 的 Buildroot 系统上找不到 libgl.so
- python - 如何在manim中为扇区的角度和旋转同时变化设置动画
- php - 排序数组的缺失键 - rsort - PHP