首页 > 解决方案 > Keras Upsampling2d -> tflite 转换导致形状推断失败和未定义的输出形状

问题描述

Keras Upsampling2d 操作通过附加操作和未定义形状转换为 this

在此处输入图像描述

然而,Tensorflow 在没有此操作的情况下会以正确的形状进行转换

在此处输入图像描述

这会导致未定义的整体模型输出形状并导致设备上的错误。如何解决这个问题?

标签: tensorflowkerastensorflow-lite

解决方案


此处描述了此行为https://github.com/tensorflow/tensorflow/issues/45090

默认情况下,Keras 将动态批量大小设置为 true。这意味着模型输入形状是 [*,28,28] 而不是 [1,28,28]。旧的(不推荐使用的)转换器用于忽略动态批处理并将其覆盖为 1 - 这是错误的,因为这不是原始模型所具有的 - 您可以想象当您尝试在运行时调整输入大小时会有多糟糕。

当前的转换器改为正确处理动态批量大小,并且生成的模型可以在运行时正确调整大小。这就是为什么“Shape、StridedSlice、Pack”的顺序不是恒定折叠的原因,因为形状取决于运行时定义的形状。

对于单输入模型,这可以通过在保存之前为 keras 模型设置恒定形状来修复

model.input.set_shape(1 + model.input.shape[1:])

推荐阅读