python - 向 Conv2D 层添加 padding='same' 参数破坏了模型
问题描述
我创建了这个模型
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D, Input, Dense
from tensorflow.keras.layers import Reshape, Flatten
from tensorflow.keras import Model
def create_DeepCAPCHA(input_shape=(28,28,1),n_prediction=1,n_class=10,optimizer='adam',
show_summary=True):
inputs = Input(input_shape)
x = Conv2D(filters=32, kernel_size=3, activation='relu', padding='same')(inputs)
x = MaxPooling2D(pool_size=2)(x)
x = Conv2D(filters=48, kernel_size=3, activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=2)(x)
x = Conv2D(filters=64, kernel_size=3, activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=2)(x)
x = Flatten()(x)
x = Dense(512, activation='relu')(x)
x = Dense(units=n_prediction*n_class, activation='softmax')(x)
outputs = Reshape((n_prediction,n_class))(x)
model = Model(inputs, outputs)
model.compile(optimizer=optimizer,
loss='categorical_crossentropy',
metrics= ['accuracy'])
if show_summary:
model.summary()
return model
我在 MNIST 数据集上尝试了模型
import tensorflow as tf
import numpy as np
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
inputs = x_train
outputs = tf.keras.utils.to_categorical(y_train, num_classes=10)
outputs = np.expand_dims(outputs,1)
model = create_DeepCAPCHA(input_shape=(28,28,1),n_prediction=1,n_class=10)
model.fit(inputs, outputs, epochs=10, validation_split=0.1)
但它未能收敛(停留在 10% 的准确率 => 与随机猜测相同)。然而,当我从 Conv2D 层中删除“padding='same'”参数时,它可以完美地工作:
def working_DeepCAPCHA(input_shape=(28,28,1),n_prediction=1,n_class=10,optimizer='adam',
show_summary=True):
inputs = Input(input_shape)
x = Conv2D(filters=32, kernel_size=3, activation='relu')(inputs)
x = MaxPooling2D(pool_size=2)(x)
x = Conv2D(filters=48, kernel_size=3, activation='relu')(x)
x = MaxPooling2D(pool_size=2)(x)
x = Conv2D(filters=64, kernel_size=3, activation='relu')(x)
x = MaxPooling2D(pool_size=2)(x)
x = Flatten()(x)
x = Dense(512, activation='relu')(x)
x = Dense(units=n_prediction*n_class, activation='softmax')(x)
outputs = Reshape((n_prediction,n_class))(x)
model = Model(inputs, outputs)
model.compile(optimizer=optimizer,
loss='categorical_crossentropy',
metrics= ['accuracy'])
if show_summary:
model.summary()
return model
任何人都知道这是什么问题?
解决方案
谢谢你的分享,这对我来说真的很有趣。所以我写了代码并测试了几个场景。请注意,我要说的只是我的客人,我不确定。
我从这些测试中得出的结论是,没有填充或valid
填充起作用,因为它会(1, 1, 64)
为最后一个卷积层生成输出形状。但是如果你将填充设置为same
它会产生(3, 3, 64)
,并且因为下一层是一个大的 Dense 层,它会将网络参数的数量乘以 9(我预计会以某种方式导致过度拟合),这似乎使网络来找到参数的好值。所以我尝试了一些不同的方法来减少最后一个卷积层的输出,(1, 1, 64)
如下所示:
- 再使用一个卷积层 + 最大池化
- 将最后一个 maxpooling 更改为 pool_size 4
- 对其中一个卷积层使用 2 的步幅
- 将最后一个卷积层的过滤器更改为 20
他们都运作良好。即使将密集单元从 512 更改为 64 也会有所帮助(请注意,即使现在你也可能会得到很差的结果,因为我猜初始化不好)。
然后我将最后一个 conv 层的形状更改为(2, 2, 64)
并降低了获得良好结果(超过 90% 准确度)的机会(很多时候我得到了 10% 的准确度)。
所以看起来很多参数会混淆模型。但是如果你想知道为什么网络没有过拟合,我没有答案。
推荐阅读
- python - python - 使用空格分隔符将字符串转换为int
- azure - Azure Spot 实例 | gpu-集群
- nm - nm 结果在符号名称中带有尾随数字
- leakcanary - LeakCanary 报告 0 个不同的泄漏,但经常转储 hprof 文件
- python - 函数 if True 返回组合列表
- sonarqube - SonarQube 8.6 - 禁用某些语言的质量配置文件
- javascript - 从外部 json 文件中检索数据
- string - [FLUTTER]:根据 ListTile 的高度剪切文本
- javascript - 如何修复 Coinbase Pro API 请求标头?
- c - HTTP/1.1 服务器实现