python - 用 CNN 代替 MLP
问题描述
我已经建立了一个具有以下架构的神经网络:
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)
print(X_train.shape, X_test.shape, Y_train.shape, Y_test.shape)
(1901, 456, 3) (476, 456, 3) (1901, 3, 3) (476, 3, 3)
model = Sequential()
model.add(Flatten(input_shape=(456,3)))
model.add(Dense(64, activation='relu'))
model.add(Dense(32, activation='relu'))
model.add(Dense(3 * 3))
model.add(Reshape((3, 3)))
model.compile('adam', 'mse')
history = model.fit(X_train, Y_train, validation_data=(X_test, Y_test), epochs=100)
现在我想用一个类似的 CNN 替换这个架构,它做同样的事情;但是在尝试实现这一点时,我总是遇到不同层的尺寸问题。而我的错误总是这样
ValueError:检查输入时出错:预期 conv2d_3_input 有 4 个维度,但得到了形状为 (x, x, x) 的数组
数据集保持不变,只是 NN 架构发生了变化,这是我的第一种方法:
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
activation='relu',
input_shape=(1901,456,3)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(3, activation='softmax'))
有人可以帮我将我的第一个 NN 替换为 CNN 吗?
解决方案
您的网络定义明确,您遇到的错误是在fit
操作期间。为什么会这样。
Conv2D 正在寻找具有4D
形状的数据,如您在此处看到的:doc
X_train 形状必须是(samples, channels, rows, cols)
当你给 时input_shape=(1901,456,3)
,你不必指定样本的数量。
但在fit
操作过程中,您需要有一个形状为 的数据(samples, channels, rows, cols)
。
现在你看到你有一个问题。为什么你的 X_train 是这样的形状,看起来你只有一个图像。您可以通过使用以下方式对其进行重塑来喂养它:
X_train = X_train.reshape((1, 1901, 456, 3))
但这似乎很奇怪,您只向网络提供了一张图像。
编辑:在澄清评论后,conv1D 在这种情况下会更好,这里是如何做到的:
model = Sequential()
model.add(Conv1D(32, kernel_size=3,
activation='relu',
input_shape=(456,3)))
model.add(Conv1D(64, 3, activation='relu'))
model.add(MaxPooling1D(pool_size=2))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(3 * 3, activation='softmax'))
model.add(Reshape((3, 3))
推荐阅读
- python - python - 如何在没有循环的情况下从List连接两个数据帧(具有相同的名称和相同的列)?
- bash - 如何定位特定的 tmux 会话
- java - Spring JPA 实体类是否可以包含不在数据库表中的非数据库字段
- c# - 尝试在 C# 中执行存储过程参数实体框架 EDMX 模型时出错
- php - 获取产品变体的 SKU
- .net - WPF .NET Clickonce 部署安装位置
- python - Python访问全局模块在导入功能中失败
- wordpress - 在“附加信息”选项卡中显示变量产品的自定义字段
- typescript - 为什么“allowSyntheticDefaultImports”和“esModuleInterop”不起作用?
- c++ - 将字符串指针传递给在 C++ 和 Xcode 11.1 中不同线程上运行的函数