首页 > 解决方案 > 在简单的前馈网络中展平

问题描述

我正在研究 CIFAR10 数据集,并在 Keras 中遇到了这个示例,使用数据增强:

https://keras.io/examples/cifar10_cnn/

该示例使用 CNN。我只想实现一个简单的前馈网络,而不是 CNN。因此,为了让我的简单模型“工作”,我必须在输出层之前添加“model.Flatten()”,以保持数据形状的一致性。

但是,我只在 CNN 中看到过使用 Flatten()。

我相信它可以用于简单的前馈网络,但我错过了什么吗?

下面是我想与 keras 示例一起使用的模型代码。

model = Sequential()
model.add(Dense(layer_size, input_shape=x_train.shape[1:], activation = "relu")
model.add(Dense(128, activation = "relu"))      
model.add(Dense(64, activation = "relu"))
model.add(Flatten())
model.add(Dense(10, activation = "softmax"))
model.summary()

谢谢

标签: tensorflowmachine-learningkerasdeep-learning

解决方案


你应该Flatten输入:

model = Sequential()
model.add(Flatten(input_shape=x_train.shape[1:]))
model.add(Dense(layer_size,activation = "relu")
model.add(Dense(128, activation = "relu"))      
model.add(Dense(64, activation = "relu"))
model.add(Dense(10, activation = "softmax"))
model.summary()

Flatten将维度张量展平为n维度张1量. 例如,2x2灰度图像变为 1 维:

[[255, 127   ],
 [154,   123]]

变成

[255, 127, 154, 123]

这样,您的输入彩色图像(3 维 , [width, height, channels])也将变为 1 维并适合Dense图层。


推荐阅读