首页 > 解决方案 > tensorflow lite中的跨步切片错误

问题描述

我正面临 tensorflow-lite 的问题。我收到此错误:

不支持类型 INT32 (2)。节点 STIDED_SLICE(编号 2)调用失败,状态为 1

我所做的是:

我用 MNIST 数据训练了一个模型。

  model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

我使用integer-only quantization转换了模型。但是,当我调用模型时,它会引发该错误。

我正在查看 striced_slice.cc,我发现了这个:

      switch (output->type) {
        case kTfLiteFloat32:
          reference_ops::StridedSlice(op_params,
                                      tflite::micro::GetTensorShape(input),
                                      tflite::micro::GetTensorData<float>(input),
                                      tflite::micro::GetTensorShape(output),
                                      tflite::micro::GetTensorData<float>(output));
          break;
        case kTfLiteUInt8:
          reference_ops::StridedSlice(
              op_params, tflite::micro::GetTensorShape(input),
              tflite::micro::GetTensorData<uint8_t>(input),
              tflite::micro::GetTensorShape(output),
              tflite::micro::GetTensorData<uint8_t>(output));
          break;
        case kTfLiteInt8:
          reference_ops::StridedSlice(op_params,
                                      tflite::micro::GetTensorShape(input),
                                      tflite::micro::GetTensorData<int8_t>(input),
                                      tflite::micro::GetTensorShape(output),
                                      tflite::micro::GetTensorData<int8_t>(output));
          break;
        default:
          TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                             TfLiteTypeGetName(input->type), input->type);

所以不支持int32。我不太确定如何处理这种问题。有没有办法改变这个节点上的行为?我应该在量化步骤中做一些不同的事情吗?

我所做的是:

def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

tflite_model = converter.convert()
open("model_int8.tflite", "wb").write(tflite_model)

PD:我正在使用 tensorflow-lite 以在 stm32 中使用。

提前致谢。

标签: c++tensorflow-litequantization

解决方案


当您执行全整数量化时,您的输入和输出应为 1 字节长(在您的情况下为 int8)。提供 int8 值作为输入,您将能够调用您的模型。


推荐阅读