python - TensorFlow 保存的模型不包含输入名称
问题描述
我们目前正在 tensorflow 2.4.0 中训练一个运行良好的对象检测模型。然而,为了能够为它服务,我们需要用一个图像预处理层包装它,该层将图像字节作为输入并将它们转换为检测模型所需的图像张量。请参阅以下代码:
png_file = 'myfile.png'
input_tensor = tf.io.read_file(png_file, name='image_bytes')
def preprocessing_layer(inputs):
image_tensor = tf.image.decode_image(inputs, channels=3)
image_tensor = tf.expand_dims(
image_tensor, axis=0, name=None
)
return image_tensor
model = keras.Sequential(
[
keras.Input(tensor=input_tensor, dtype=tf.dtypes.string, name='image_bytes', batch_size=1),
tf.keras.layers.Lambda(lambda inp: preprocessing_layer(inp)),
yolo_model
]
)
model.summary()
这个包装模型提供了有用的检测,如果我们调用model.input_names
正确的名称,则返回:['image_bytes']
.
现在,如果我们使用保存的模型model.save('model_path')
保存模型,则不再包含输入名称并用通用名称替换它们(args_0
)。
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['args_0'] tensor_info:
dtype: DT_STRING
shape: ()
name: serving_default_args_0:0
The given SavedModel SignatureDef contains the following output(s):
outputs['model'] tensor_info:
dtype: DT_FLOAT
shape: (1, 64512, 6)
这是一个问题,因为 tensorflow 服务依赖于以 结尾的名称_bytes
来转换 base64 输入。
您能否提供有关在保存模型时如何保留输入名称的提示?
解决方案
问题源于您定义 lambda 层的方式以及设置模型的方式。
您的 lambda 函数应该能够处理批处理,但目前并非如此。您可以天真地使用tf.map_fn
它来处理一批图像,如下所示:
def preprocessing_layer(str_inputs):
def decode(inputs):
image_tensor = tf.image.decode_image(inputs[0], channels=3)
image_tensor = tf.expand_dims(
image_tensor, axis=0, name=None
)
return image_tensor
return tf.map_fn(decode, str_inputs, fn_output_signature=tf.uint8)
然后您可以使用符号定义您的模型tf.keras.Input
,将形状设置为()
(以指定批量大小以外的任何尺寸):
model = keras.Sequential(
[
keras.Input((), dtype=tf.dtypes.string, name='image_bytes'),
tf.keras.layers.Lambda(lambda inp: preprocessing_layer(inp)),
yolo_model
]
)
现在模型已正确创建,并且可以正确导出签名。
推荐阅读
- typescript - 在打字稿中找不到名称“内容”
- python - 无法使用 Bash Operator 为 python 脚本运行 dag
- python - 用 dtype 拟合具有 numpy 变化的多项式,即使实际数据值保持不变
- pandas - Pandas Mulmtiindex 多列第 1 级
- javascript - Chrome 开发工具中的“Ot”是什么意思?
- google-chrome-extension - 在 Manifest v3 中使用地理位置
- python - python if/else 和 try/except 组合没有冗余
- android - 为什么 toPx() 不能在 Canvas 之外工作?
- r - 错误:无法产生函数的结果
- java - 添加库 org.springframework.boot.context。... 2.4.2 到每个项目模型中的类路径