首页 > 解决方案 > Keras中非常奇怪的Reshape层行为

问题描述

keras: 2.2.4
TF: 1.15.0

当我创建一个网络时,像这样:

print(priorbox3.shape) #got (?, 38, 38, 3, 8)
priorbox3_reshape = Reshape((38*38*3, 4))(priorbox3)
print(priorbox3_reshape.shape) #(?, 4332, 4)

它运行成功,但实际上38 * 38 * 3 * 8 != 4332 * 4! 很奇怪,如果我set Reshape((38 * 38 * 3, 8)),我会得到一个不匹配的错误。

标签: tensorflowkeras

解决方案


重塑层不检查维度完整性,因为在此基础上,它是一个 tensorflow(如果您使用 tensorflow 作为后端)占位符。开头的问号意味着它可以是任何东西。所以这意味着无论你如何重塑,它都是一个有效的形状,因为 ? 代表任何东西。在你的情况下。38x38x3x8 / 4 = 8664。但是您指定了 4332,这意味着将有 2 个维度进入重塑张量的问号。在运行时,假设您的批量大小为 4,那么您的重塑应该是(8, 4332, 4)然后 tf 将引发运行时错误。

同样,重塑层只是为了构建一个图形。在这种情况下,错误会延迟到运行时


推荐阅读