python - 在 Keras 中定义新的 Lambda 层时重塑错误
问题描述
我试图在 Keras 中实现一个乘法层,但我收到了多个Reshape
相关的错误。尽管它们现在都已解决,但我仍然怀疑为什么会这样。所以这是我实现的代码块:
out2 = Dense(540, kernel_initializer='glorot_normal', activation='linear')(out2)
out2 = LeakyReLU(alpha=.2)(out2)
out2 = Reshape((9, 4, 15))(out2)
out2 = Lambda(lambda x: K.dot(K.permute_dimensions(x, (0, 2, 1, 3)), K.permute_dimensions(x, (0, 2, 3, 1))), output_shape=(4,9,9))(out2)
out2 = Dense(324, kernel_initializer='glorot_normal', activation='linear')(out2)
# K.dot should be of size (-1, 4, 9, 9), so I set output 324, and later on, reshape the ata
out2 = LeakyReLU(alpha=.2)(out2)
out2 = Reshape((-1, 4, 9, 9))(out2)
out2 = Permute((0, 2, 3, 1))(out2)
这现在工作正常。但我做了 3 件我不满意的事情:
我曾经有过,
out2 = Reshape((-1, 9, 4, 15))(out2)
但out2 = Reshape((9, 4, 15))(out2)
我有错误ValueError: Dimension must be 5 but is 4 for 'lambda_1/transpose' (op: 'Transpose') with input shapes: [?,?,9,4,15], [4].
显然,我没有考虑批量大小。
现在我尝试更正该行
out2 = Reshape((-1, 4, 9, 9))(out2)
以out2 = Reshape((4, 9, 9))(out2)
使用相同的概念,但是随后它引发了错误ValueError: total size of new array must be unchanged
我不明白不一致之处。
- 最后,我想知道删除
output_shape=(4,9,9)
是否会对代码造成任何错误。
解决方案
关于批量大小的问题,Keras 会自动处理。层代表要应用于批次的函数只是一个惯例,Keras 的任务是将此类函数应用于模型所馈送的每个批次。所以,基本上,你应该在定义层时忽略批量大小。
此外,该Dense
图层无法按预期工作。它应用于其输入的最后一个维度。如果您想从那时起将您的数据作为常规 MLP 处理,您可以像使用以常规全连接层结尾的 CNN 一样使用Flatten()
之前使用的数据(当然您可以在之后对其进行重塑)。Dense
总而言之,您可以执行以下操作:
out2 = Dense(540, kernel_initializer='glorot_normal', activation='linear')(out2)
out2 = LeakyReLU(alpha=.2)(out2)
out2 = Reshape((9, 4, 15))(out2)
out2 = Lambda(lambda x: K.dot(K.permute_dimensions(x, (0, 2, 1, 3)), K.permute_dimensions(x, (0, 2, 3, 1))), output_shape=(4,9,9))(out2)
out2 = Flatten()(out2)
out2 = Dense(324, kernel_initializer='glorot_normal', activation='linear')(out2)
out2 = LeakyReLU(alpha=.2)(out2)
out2 = Reshape((4, 9, 9))(out2)
推荐阅读
- office-js - 通过 Office.context.ui.displayDialogAsync() 打开 URL 时无法重现的验证失败
- python - 使用 pyaudio 生成幅度增加的音调
- javascript - Vuejs在模板中显示多个值
- java - 从 Android Studio 上的 sqlite 表中删除一行不起作用
- java - 如何将“hh:mm:ss”转换为sql更新间隔
- r - 如何从R中的逻辑回归中获得中位数优势比?
- javascript - Express res.render 渲染错误的文件
- sql - 使用 DBplyr 连接到数据库
- javascript - 尝试从 MediaWiki API 获取一条数据时出现“TypeError:无法读取未定义的属性“0”
- javascript - 为什么我 *inconsistently* 得到 DOMException: Blocked a frame with origin "https://ec2b.foo.com" from access a cross-origin frame