tensorflow - tf.keras model.fit():相同数据上的 train loss 和 val loss 之间的巨大差异
问题描述
TensorFlow 2.1 版
请参阅 colab 笔记本以重现该问题:https ://drive.google.com/file/d/1Fvc6G_9v5mek015cai7qYT6HoY-fLkzk/view?usp=sharing
当训练损失下降时,val_loss 不会改变,尽管这是完全相同的数据。
训练 2 个样本,验证 2 个样本
Epoch 1/30 2/2 [=============================] - 3s - 2s/样本 - 损失:0.4630 - val_loss:302.4763
Epoch 2/30 2/2 [============================= ] -1s - 457ms/sample - loss: 0.8565 - val_loss: 496.9578
Epoch 3/30 2/2 [========================== ===] - 1s - 457ms/样本 - 损失:0.7886 - val_loss:1050.9148
Epoch 4/30 2/2 [======================= ======] - 1s - 450ms/样本 - 损失:0.1080 - val_loss:744.4895
Epoch 5/30 2/2 [==================== =========] - 1s - 474ms/样本 - 损失:0.1144 - val_loss:1353.2678
Epoch 6/30 2/2 [================== ============] - 1s - 465ms/样本 - 损失:0.0402 - val_loss:3237.9683
纪元 7/30 2/2 [===============================] - 1s - 465ms/样本 - 损失:0.0635 - val_loss: 3946.7822
Epoch 8/30 2/2 [==============================] - 1s - 470ms/sample - 损失: 0.0355 - val_loss: 4054.5461
纪元 9/30 2/2 [==============================] - 1s - 462ms/样本 - 损失:0.0345 - val_loss:4991.5400
这怎么可能?代码非常简单:
ResNet18, preprocess_input = Classifiers.get('resnet18')
base_model = ResNet18(input_shape=(180, 320, 3), weights=None, include_top=False)
x = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
output = tf.keras.layers.Dense(8)(x)
model = tf.keras.models.Model(inputs=base_model.input, outputs=output)
model.compile(optimizer='adam', loss='mse')
data = np.random.rand(2, 180, 320, 3)
labels = np.random.rand(2, 8)
model.fit(data, labels, validation_data=(data,labels), batch_size=2, epochs=30)
keras 和批处理规范化存在已知问题(参见例如 keras-team/keras#6977)。这可能是相关的,但我没有直接看到如何。我必须进行哪些更改才能使其按预期工作?这是包含的包https://github.com/qubvel/classification_models中的东西还是在哪里解决它?
编辑:从 TF 2.0 开始,批量标准化的行为发生了变化,因此其他问题可能不相关,请参阅https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization
解决方案
您没有将数据预处理到模型预期的范围。该模型包含BatchNormalization
,因此您不能使用不在正确范围内的数据。
您需要preprocess_input
在所有数据中使用,当然,您的初始范围应该是 0 到 255,因为它适用于图像。
理想情况下,您应该使用实际图像,因此您有类似的分布。
推荐阅读
- javascript - 拆分日期范围重叠成块(javascript)
- java - 我的应用程序我想将一些测试显示为 url 链接,因为当我们键入或复制粘贴 url 时,watspp 会显示
- xamarin.forms - 扫描条码时使用手电筒
- android - 如何获取特定字符串格式的日期完整时间
- c# - 设计视图模型并避免在控制器中使用 if else 语句并使用设计模式编写良好的业务逻辑 asp.net web api
- django - Django:如何在 sqlite3 数据库中添加/删除字段?
- mysql - 列出受相同抗生素影响的细菌对
- sql - 三个条件以防万一
- java - 算法实现java(algo)
- angular - 将@angular/pwa 添加到现有项目