首页 > 解决方案 > tf.keras model.fit():相同数据上的 train loss 和 val loss 之间的巨大差异

问题描述

TensorFlow 2.1 版

请参阅 colab 笔记本以重现该问题:https ://drive.google.com/file/d/1Fvc6G_9v5mek015cai7qYT6HoY-fLkzk/view?usp=sharing

当训练损失下降时,val_loss 不会改变,尽管这是完全相同的数据。

训练 2 个样本,验证 2 个样本
Epoch 1/30 2/2 [=============================] - 3s - 2s/样本 - 损失:0.4630 - val_loss:302.4763
Epoch 2/30 2/2 [============================= ] -1s - 457ms/sample - loss: 0.8565 - val_loss: 496.9578
Epoch 3/30 2/2 [========================== ===] - 1s - 457ms/样本 - 损失:0.7886 - val_loss:1050.9148
Epoch 4/30 2/2 [======================= ======] - 1s - 450ms/样本 - 损失:0.1080 - val_loss:744.4895
Epoch 5/30 2/2 [==================== =========] - 1s - 474ms/样本 - 损失:0.1144 - val_loss:1353.2678
Epoch 6/30 2/2 [================== ============] - 1s - 465ms/样本 - 损失:0.0402 - val_loss:3237.9683
纪元 7/30 2/2 [===============================] - 1s - 465ms/样本 - 损失:0.0635 - val_loss: 3946.7822
Epoch 8/30 2/2 [==============================] - 1s - 470ms/sample - 损失: 0.0355 - val_loss: 4054.5461
纪元 9/30 2/2 [==============================] - 1s - 462ms/样本 - 损失:0.0345 - val_loss:4991.5400

这怎么可能?代码非常简单:

ResNet18, preprocess_input = Classifiers.get('resnet18')
base_model = ResNet18(input_shape=(180, 320, 3), weights=None, include_top=False)
x = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
output = tf.keras.layers.Dense(8)(x)
model = tf.keras.models.Model(inputs=base_model.input, outputs=output)
model.compile(optimizer='adam', loss='mse')
data = np.random.rand(2, 180, 320, 3)
labels = np.random.rand(2, 8)

model.fit(data, labels, validation_data=(data,labels), batch_size=2, epochs=30)

keras 和批处理规范化存在已知问题(参见例如 keras-team/keras#6977)。这可能是相关的,但我没有直接看到如何。我必须进行哪些更改才能使其按预期工作?这是包含的包https://github.com/qubvel/classification_models中的东西还是在哪里解决它?

编辑:从 TF 2.0 开始,批量标准化的行为发生了变化,因此其他问题可能不相关,请参阅https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization

标签: tensorflowkerastensorflow2.0

解决方案


您没有将数据预处理到模型预期的范围。该模型包含BatchNormalization,因此您不能使用不在正确范围内的数据。

您需要preprocess_input在所有数据中使用,当然,您的初始范围应该是 0 到 255,因为它适用于图像。

理想情况下,您应该使用实际图像,因此您有类似的分布。


推荐阅读