首页 > 解决方案 > 波动训练损失背后的直觉

问题描述

我正在尝试为 28x28x5 图像构建卷积自动编码器。以下是我的模型的摘要:


Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 28, 28, 16)        736
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 16)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 8)         1160
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 7, 7, 8)           0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 8)           584
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 4, 4, 8)           0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 4, 4, 8)           584
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 8, 8, 8)           0
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 8, 8, 8)           584
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 16, 16, 8)         0


_________________________________________________________________
conv2d_6 (Conv2D)            (None, 14, 14, 16)        1168
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 28, 28, 16)        0
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 28, 28, 5)         725
=================================================================
Total params: 5,541
Trainable params: 5,541
Non-trainable params: 0

在绘制每个时期的训练和测试误差时,我得到下图: 在此处输入图像描述 学习率从 0.01 开始,达到 0.01/750 = 0,000013333

我认为波动可能是因为学习率太高,所以我再次尝试使用 0.00001 到 0,0000002,结果是: 在此处输入图像描述

为什么训练损失总是波动这么大,而测试误差几乎是恒定的?这是正常行为吗?对我来说,他们似乎应该表现得相似。第一张图中两个损失的初始减少让我相信代码至少在做正确的事情,但其余的感觉是错误的。

下面是我的训练代码:

(trainX, testX, trainY, testY) = train_test_split(newData,
    newData, test_size=0.25)

# construct the image generator for data augmentation
aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,
    height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
    horizontal_flip=True, fill_mode="nearest")

model = AutoEncoder.build(width=28, height=28, depth=5)

# initialize our initial learning rate, # of epochs to train for,
# and batch size
INIT_LR = 0.001
EPOCHS = 50
BS = 16

opt = SGD(lr=INIT_LR, decay=INIT_LR / EPOCHS)

model.compile(loss="mean_squared_error", optimizer = opt)

H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
    validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
    epochs=EPOCHS)

编辑:

训练集中图像层中的示例行:

[ 7.41672516e-03  3.14044952e-03 -1.39656067e-02 -1.18265152e-02
  6.34765625e-03  1.27620697e-02  1.49002075e-02 -6.48117065e-03
  1.00231171e-03 -2.20489502e-03  3.14044952e-03  1.00231171e-03
  1.00231171e-03  3.14044952e-03  1.00231171e-03  4.20951843e-03
 -1.13582611e-03  1.06239319e-02  6.34765625e-03  4.20951843e-03
  1.49002075e-02  4.20951843e-03 -6.67572021e-05 -6.67572021e-05
 -1.13582611e-03 -3.27396393e-03 -5.41210175e-03 -6.48117065e-03]

测试集:

[ 1.3831139e-02  1.2762070e-02  7.4167252e-03  2.0713806e-03
 -1.1826515e-02 -7.5502396e-03  1.0023117e-03  3.1404495e-03
  2.6660919e-02 -6.6757202e-05  6.3476562e-03  2.0713806e-03
 -7.5502396e-03 -1.1358261e-03  1.0023117e-03  1.0023117e-03
 -1.1358261e-03 -6.4811707e-03  5.2785873e-03 -4.3430328e-03
  1.0023117e-03  1.1693001e-02  2.3453712e-02  1.3831139e-02
  1.9177437e-02  1.5969276e-02  2.0713806e-03  2.0713806e-03]

因为它们都来自同一个输入数据集,所以我认为这不是问题

编辑 2:使用 450000 张图像进行训练会导致: 在此处输入图像描述

添加更多的训练数据似乎可以解决问题,尽管我仍然很奇怪为什么训练误差波动如此之大,而验证误差却没有。

编辑 3:

还增加了批量大小: 在此处输入图像描述

标签: pythontensorflowmachine-learningdata-scienceautoencoder

解决方案


我认为模型输出是一个常数。你能检查一下吗?您的验证集中有多少样本?如果模型输出恒定,则训练损失会波动(因为不同批次中的样本不同),而验证损失将保持恒定。也许尝试改变学习率/批量大小或添加更多数据(也可以添加到验证集)。


推荐阅读