python - 波动训练损失背后的直觉
问题描述
我正在尝试为 28x28x5 图像构建卷积自动编码器。以下是我的模型的摘要:
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 28, 28, 16) 736
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 16) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 14, 14, 8) 1160
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 7, 7, 8) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 7, 7, 8) 584
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 4, 4, 8) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 4, 4, 8) 584
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 8, 8, 8) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 8, 8, 8) 584
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 16, 16, 8) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 14, 14, 16) 1168
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 28, 28, 16) 0
_________________________________________________________________
conv2d_7 (Conv2D) (None, 28, 28, 5) 725
=================================================================
Total params: 5,541
Trainable params: 5,541
Non-trainable params: 0
在绘制每个时期的训练和测试误差时,我得到下图: 学习率从 0.01 开始,达到 0.01/750 = 0,000013333
我认为波动可能是因为学习率太高,所以我再次尝试使用 0.00001 到 0,0000002,结果是:
为什么训练损失总是波动这么大,而测试误差几乎是恒定的?这是正常行为吗?对我来说,他们似乎应该表现得相似。第一张图中两个损失的初始减少让我相信代码至少在做正确的事情,但其余的感觉是错误的。
下面是我的训练代码:
(trainX, testX, trainY, testY) = train_test_split(newData,
newData, test_size=0.25)
# construct the image generator for data augmentation
aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,
height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
horizontal_flip=True, fill_mode="nearest")
model = AutoEncoder.build(width=28, height=28, depth=5)
# initialize our initial learning rate, # of epochs to train for,
# and batch size
INIT_LR = 0.001
EPOCHS = 50
BS = 16
opt = SGD(lr=INIT_LR, decay=INIT_LR / EPOCHS)
model.compile(loss="mean_squared_error", optimizer = opt)
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
epochs=EPOCHS)
编辑:
训练集中图像层中的示例行:
[ 7.41672516e-03 3.14044952e-03 -1.39656067e-02 -1.18265152e-02
6.34765625e-03 1.27620697e-02 1.49002075e-02 -6.48117065e-03
1.00231171e-03 -2.20489502e-03 3.14044952e-03 1.00231171e-03
1.00231171e-03 3.14044952e-03 1.00231171e-03 4.20951843e-03
-1.13582611e-03 1.06239319e-02 6.34765625e-03 4.20951843e-03
1.49002075e-02 4.20951843e-03 -6.67572021e-05 -6.67572021e-05
-1.13582611e-03 -3.27396393e-03 -5.41210175e-03 -6.48117065e-03]
测试集:
[ 1.3831139e-02 1.2762070e-02 7.4167252e-03 2.0713806e-03
-1.1826515e-02 -7.5502396e-03 1.0023117e-03 3.1404495e-03
2.6660919e-02 -6.6757202e-05 6.3476562e-03 2.0713806e-03
-7.5502396e-03 -1.1358261e-03 1.0023117e-03 1.0023117e-03
-1.1358261e-03 -6.4811707e-03 5.2785873e-03 -4.3430328e-03
1.0023117e-03 1.1693001e-02 2.3453712e-02 1.3831139e-02
1.9177437e-02 1.5969276e-02 2.0713806e-03 2.0713806e-03]
因为它们都来自同一个输入数据集,所以我认为这不是问题
添加更多的训练数据似乎可以解决问题,尽管我仍然很奇怪为什么训练误差波动如此之大,而验证误差却没有。
编辑 3:
解决方案
我认为模型输出是一个常数。你能检查一下吗?您的验证集中有多少样本?如果模型输出恒定,则训练损失会波动(因为不同批次中的样本不同),而验证损失将保持恒定。也许尝试改变学习率/批量大小或添加更多数据(也可以添加到验证集)。
推荐阅读
- python - 图像处理中的内存管理
- matlab - 朴素的高斯消除 Matlab 问题
- amazon-web-services - 销毁 AWS-CDK 中的堆栈时不要删除现有资源
- ios - 每次重新打开浏览器时都忘记了 Django 会话 cookie - 移动 Safari(iphone、ipad)
- java - Spring Boot:创建自定义 Jsp 标记 - 无法找到 taglib
- html - 如何修复我的 GitHub Pages 站点上的 Sass?
- gitlab - gitlab 的标准安装产生 404
- deep-linking - 如何在pdf文档中创建一个非常深的链接
- haskell - c 头文件更改后 Cabal 不重建项目
- android - React Native Android:如何检查项目是否使用来自 Homebrew 或 gradle-wrapper.properties 的 Gradle 版本?