python - TFLite 量化模型仍然输出浮点数
问题描述
我有一个 CNN 已经在工作,但现在有必要将它放在一些特定的硬件中。为此,我被告知要量化模型,因为硬件只能使用整数运算。
我在这里阅读了一个很好的解决方案: 如何确保 TFLite 解释器仅使用 int8 操作?
我编写了这段代码以使其工作:
model_file = "models/my_cnn.h5"
# load data
model = tf.keras.models.load_model(model_file, custom_objects={'tf': tf}, compile=False)
# convert
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint16 # or tf.uint8
converter.inference_output_type = tf.uint16 # or tf.uint8
qmodel = converter.convert()
with open('thales.tflite', 'wb') as f:
f.write(qmodel)
interpreter = tf.lite.Interpreter(model_content=qmodel)
interpreter.allocate_tensors()
# predict
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)
image = read_image("test.png")
interpreter.set_tensor(input_details[0]['index'], image)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
当我们查看打印的输出时,我们可以看到,首先是详细信息:
input_details
[{'name': 'input_1', 'index': 87, 'shape': array([ 1, 160, 160, 3], dtype=int32), 'shape_signature': array([ 1, 160, 160, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
output_details
[{'name': 'Identity', 'index': 88, 'shape': array([ 1, 160, 160, 1], dtype=int32), 'shape_signature': array([ 1, 160, 160, 1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
量化模型的输出为:
...
[[0. ]
[0. ]
[0. ]
...
[0.00390625]
[0.00390625]
[0.00390625]]
[[0. ]
[0. ]
[0. ]
...
[0.00390625]
[0.00390625]
[0.00390625]]]]
所以,我在这里有几个问题:
在输入/输出细节中我们可以看到输入/输出层是int32,但是我在代码中指定了uint16
同样在输入/输出细节中,我们可以看到多次“float32”作为 dtype 出现,我不明白为什么。
最后,最大的问题是输出包含浮点数,这是不应该发生的。所以看起来模型并没有真正转换为整数。
我怎样才能真正量化我的 CNN 以及为什么它不能工作这个代码?
解决方案
The converter.inference_input_type
and converter.inference_output_type
support only tf.int8
or tf.uint8
, not tf.uint16
.
推荐阅读
- qt - 在 Windows 10 中使用 QSystemTrayIcon 的 showMessage 时显示应用程序名称及其扩展名
- java - 如何在 JDK 13 中创建 javaFX 程序
- spring-mvc - 在此行发现多个注释:- 引用的文件包含错误
- jquery - .hover 和 .on 与 jQuery 的问题
- java - 从 Spring Batch 自定义 ItemReader 返回对象列表
- python-3.x - configparser 无法读取常规配置文件,没有部分
- javascript - 在 React 组件中显示键不同但值相同的对象中的数据
- javascript - 从 PHP 调用 Powershell 脚本不起作用
- c# - 如何通过 web api 从 Azure Blob 存储中获取图像作为二进制文件?
- html - 选择 tr 元素为特定表格设置样式