python - Keras模型的训练后全整数量化
问题描述
我正在尝试对 Keras 模型进行训练后全 8 位量化,以编译并部署到 EdgeTPU。我有一个保存为 .h5 文件的训练有素的 Keras 模型,并且正在尝试按照此处指定的步骤进行操作:https ://coral.withgoogle.com/docs/edgetpu/models-intro/ ,以部署到 Coral 开发板.
我正在遵循这些量化说明:https ://www.tensorflow.org/lite/performance/post_training_quantization#full_integer_quantization_of_weights_and_activations )
我正在尝试使用以下代码:
import tensorflow as tf
num_calibration_steps = 100
def representative_dataset_gen():
for _ in range(num_calibration_steps):
# Get sample input data as a numpy array in a method of your choosing.
yield [X_train_quant_conv]
converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file('/tmp/classNN_simple.h5')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
converter.representative_dataset = representative_dataset_gen
tflite_full_integer_quant_model = converter.convert()
我的训练数据的子集在哪里X_train_quant_conv
转换为np.array
类型np.float32
运行这段代码时,我收到以下错误:
ValueError: Cannot set tensor: Dimension mismatch
我尝试representative_dataset_gen()
以不同的方式更改功能,但每次我遇到新错误。我不确定这个功能应该如何。我也怀疑什么 value num_calibration_steps should have
。
非常感谢任何建议或工作示例。
这个问题与这个回答的问题非常相似:Convert Keras model to quantized Tensorflow Lite model that can be used on Edge TPU
解决方案
您可能想在 github 上查看我的量化演示脚本。
这只是一个猜测,因为我看不到X_train_quant_conv
真正的内容,但在我的工作演示中,我一次生成一个图像(在我的情况下是动态创建的随机数据)在representative_dataset_gen()
. 图像存储为大小为 1 的批次(例如,对于我的 52x52x32 图像,张量形状为 (1, 56, 56, 32))。彩色图像有 32 个通道,但通常只有 3 个。我认为representative_dataset_gen()
必须产生一个包含张量(或多个?)的列表,其中第一个维度的长度为 1。
image_shape = (56, 56, 32)
def representative_dataset_gen():
num_calibration_images = 10
for i in range(num_calibration_images):
image = tf.random.normal([1] + list(image_shape))
yield [image]
推荐阅读
- php - 未定义路由登录(多身份验证)
- azure-data-factory - 如何在存储帐户中使用 azure 数据工厂迁移表,该表具有多种类型
- if-statement - 如果用户没有为 Active Choice Parameterized 管道构建输入值,则退出管道
- c# - 为什么我不能在剃刀条件语句中使用变量
- python - 这段代码怎么能写在一行里?如何写得更 Python 化?
- java - 致命异常:java.lang.IllegalStateException - 无法为 LinearLayout 创建层(仅在 Galaxy j4+、j6+ 中崩溃)
- django - 在 Django 的 url 中具有主键的 Ajax 无法工作
- c++ - char* 和 std::string 和 const char* 之间的转换
- javascript - 有没有办法将滚动移动到已搜索的列
- linux - 在 Linux 内核 Makefile 中回显字符串