python - 在 Keras 中使用迁移学习训练 CNN - 图像输入不起作用,但矢量输入起作用
问题描述
我正在尝试在 Keras 中进行迁移学习。我设置了一个 ResNet50 网络设置为不可训练一些额外的层:
# Image input
model = Sequential()
model.add(ResNet50(include_top=False, pooling='avg')) # output is 2048
model.add(Dropout(0.05))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.15))
model.add(Dense(512, activation='relu'))
model.add(Dense(7, activation='softmax'))
model.layers[0].trainable = False
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
然后我创建输入数据:x_batch
使用 ResNet50preprocess_input
函数,以及一个热编码标签y_batch
,并按如下方式进行拟合:
model.fit(x_batch,
y_batch,
epochs=nb_epochs,
batch_size=64,
shuffle=True,
validation_split=0.2,
callbacks=[lrate])
十个左右的 epoch 后训练准确率接近 100%,但验证准确率实际上从 50% 左右下降到 30%,验证损失稳步增加。
但是,如果我改为创建一个仅包含最后一层的网络:
# Vector input
model2 = Sequential()
model2.add(Dropout(0.05, input_shape=(2048,)))
model2.add(Dense(512, activation='relu'))
model2.add(Dropout(0.15))
model2.add(Dense(512, activation='relu'))
model2.add(Dense(7, activation='softmax'))
model2.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model2.summary()
并输入 ResNet50 预测的输出:
resnet = ResNet50(include_top=False, pooling='avg')
x_batch = resnet.predict(x_batch)
然后验证准确率达到 85% 左右……这是怎么回事?为什么图像输入法不起作用?
更新:
这个问题真的很奇怪。如果我将 ResNet50 更改为 VGG19,它似乎可以正常工作。
解决方案
经过大量谷歌搜索后,我发现问题与 ResNet 中的批量标准化层有关。VGGNet 中没有批量归一化层,这就是它适用于该拓扑的原因。
这里有一个在 Keras 中修复此问题的拉取请求,其中更详细地解释了:
假设我们使用 Keras 的预训练 CNN 之一,并且我们想要对其进行微调。不幸的是,我们无法保证 BN 层内的新数据集的均值和方差与原始数据集的均值和方差相似。因此,如果我们微调顶层,它们的权重将调整为新数据集的均值/方差。然而,在推理过程中,顶层将接收使用原始数据集的均值/方差进行缩放的数据。这种差异会导致准确性降低。
这意味着 BN 层正在根据训练数据进行调整,但是在执行验证时,将使用 BN 层的原始参数。据我所知,解决方法是允许冻结的 BN 层使用来自训练的更新均值和方差。
一种解决方法是预先计算 ResNet 输出。事实上,这大大减少了训练时间,因为我们没有重复计算的那部分。
推荐阅读
- java - 处理两个不同文件扩展名的正则表达式
- composer-php - 我可以清理/删除本地 Composer 目录吗?
- javascript - Typescript 或 babel 插件将 ES6 类转换为使用作用域变量作为私有属性模式的函数
- java - Java:在没有 Passay 和 VT-Password 库的情况下验证键盘序列规则
- symfony - 在 Symfony 中的 beforeSend 函数中获取 rootDir 以拦截 Sentry 事件
- php - 带有变量的 Laravel 视图,缺少路由所需的参数
- asp.net-core - 如何在 ASP.NET Core 3.1 中创建动态路由?
- nginx - Nginx 将子域 foo 重定向到 www.foo 到 www.www.foo 等
- python - Visual Studio 代码 - 未使用的导入警告 - Python
- mysql - 如何从 Sequelize 中的包含中引用父模型列?