python - Keras 2D 输入到 2D 输出
问题描述
首先,我已经阅读了这个和这个与我的名字相似的问题,但仍然没有答案。
我想为序列预测建立一个前馈网络。(我意识到 RNN 更适合这项任务,但我有我的理由)。序列的长度为 128,每个元素是一个包含 2 个条目的向量,因此每个批次应该是 shape(batch_size, 128, 2)
并且目标是序列中的下一步,所以目标张量应该是 shape (batch_size, 1, 2)
。
网络架构是这样的:
model = Sequential()
model.add(Dense(50, batch_input_shape=(None, 128, 2), kernel_initializer="he_normal" ,activation="relu"))
model.add(Dense(20, kernel_initializer="he_normal", activation="relu"))
model.add(Dense(5, kernel_initializer="he_normal", activation="relu"))
model.add(Dense(2))
但试图训练我得到以下错误:
ValueError: Error when checking target: expected dense_4 to have shape (128, 2) but got array with shape (1, 2)
我尝试过以下变体:
model.add(Dense(50, input_shape=(128, 2), kernel_initializer="he_normal" ,activation="relu"))
但得到同样的错误。
解决方案
如果您查看model.summary()
输出,您会发现问题所在:
Layer (type) Output Shape Param #
=================================================================
dense_13 (Dense) (None, 128, 50) 150
_________________________________________________________________
dense_14 (Dense) (None, 128, 20) 1020
_________________________________________________________________
dense_15 (Dense) (None, 128, 5) 105
_________________________________________________________________
dense_16 (Dense) (None, 128, 2) 12
=================================================================
Total params: 1,287
Trainable params: 1,287
Non-trainable params: 0
_________________________________________________________________
如您所见,模型的输出与您预期的(None, 128,2)
不同(None, 1, 2)
(或)。(None, 2)
因此,您可能知道也可能不知道Dense 层应用在其输入数组的最后一个轴上,因此,正如您在上面看到的,时间轴和维度会一直保留到最后。
如何解决这个问题?您提到您不想使用 RNN 层,因此您有两个选择:您需要Flatten
在模型中的某处使用层,或者您也可以使用一些 Conv1D + Pooling1D 层,甚至是 GlobalPooling 层。例如(这些只是为了演示,你可以做不同的):
使用Flatten
层
model = models.Sequential()
model.add(Dense(50, batch_input_shape=(None, 128, 2), kernel_initializer="he_normal" ,activation="relu"))
model.add(Dense(20, kernel_initializer="he_normal", activation="relu"))
model.add(Dense(5, kernel_initializer="he_normal", activation="relu"))
model.add(Flatten())
model.add(Dense(2))
model.summary()
型号总结:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_17 (Dense) (None, 128, 50) 150
_________________________________________________________________
dense_18 (Dense) (None, 128, 20) 1020
_________________________________________________________________
dense_19 (Dense) (None, 128, 5) 105
_________________________________________________________________
flatten_1 (Flatten) (None, 640) 0
_________________________________________________________________
dense_20 (Dense) (None, 2) 1282
=================================================================
Total params: 2,557
Trainable params: 2,557
Non-trainable params: 0
_________________________________________________________________
使用GlobalAveragePooling1D
层
model = models.Sequential()
model.add(Dense(50, batch_input_shape=(None, 128, 2), kernel_initializer="he_normal" ,activation="relu"))
model.add(Dense(20, kernel_initializer="he_normal", activation="relu"))
model.add(GlobalAveragePooling1D())
model.add(Dense(5, kernel_initializer="he_normal", activation="relu"))
model.add(Dense(2))
model.summary()
型号总结:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_21 (Dense) (None, 128, 50) 150
_________________________________________________________________
dense_22 (Dense) (None, 128, 20) 1020
_________________________________________________________________
global_average_pooling1d_2 ( (None, 20) 0
_________________________________________________________________
dense_23 (Dense) (None, 5) 105
_________________________________________________________________
dense_24 (Dense) (None, 2) 12
=================================================================
Total params: 1,287
Trainable params: 1,287
Non-trainable params: 0
_________________________________________________________________
请注意,在上述两种情况下,您都需要将标签(即目标)数组重塑为(n_samples, 2)
(或者您可能希望Reshape
在最后使用一个图层)。
推荐阅读
- ios - UIView:如何通过程序更新XIB中的坐标
- google-apps-script - Google Apps 脚本错误“异常:无效参数:选项。应该是类型:地图”
- android - 在 android 中生成带有预告片、前导和分隔符的 EAN-13
- ubuntu - 在 Ubuntu 20.04 上编译和运行 Xrotor
- javascript - 希望在 github 页面上创建一个基于降价的博客
- laravel - 为什么我不能在黄昏的测试用例中注入一个类
- r - 如何删除与 R 中超过 1 个模式匹配的多行?
- c# - 多个 keydown 旋转会导致不需要的角度
- javascript - 是否有我可以嵌入的具有导航/方向的地图?| HTML
- firebase - Firebase 存储映像:哪种方式更便宜:将 base64 字符串放入 Firestore 或将映像存储在 Storage 中?