tensorflow - Keras:根据 ModelCheckpoint-callback 的最佳模型在训练集上产生的损失与训练时显示的最佳时期损失不同
问题描述
我正在尝试用 Python 中的 TensorFlow 后端训练一个非常简单的 Keras 模型。
我知道训练时控制台中显示的历元损失是为了提高效率而“动态”计算的,因此不一定是中间模型的实际损失。但据我了解,如果每个 epoch 只包含一个批次,即整个训练集,它们实际上应该是这样。这种期望的原因是,在这种情况下,模型的权重仅在每个 epoch 结束时更新一次,这意味着在计算 epoch 的损失时模型不会改变。
不幸的是,即使我将批量大小设置为训练集的大小,最佳时期的损失也不同于根据 ModelCheckpoint 回调的最佳模型损失。
有人可以向我解释这种行为吗?ModelCheckpoint-callback 是否也只能在某种“动态”中计算中间模型的损失?
这是我的代码,其中bestEpochLoss
和bestModelLoss
永远不一样:
import numpy
import keras
#Create train data
trainInput = numpy.array([4,3,1,0,2])
trainOutput = numpy.array([0,2,2,0,1])
#Create and train model
model = keras.Sequential([
keras.layers.Dense(200, input_shape=(1,), activation='tanh'),
keras.layers.Dense(1, activation='linear')
])
model.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam(lr=0.1))
callbacks = [keras.callbacks.ModelCheckpoint(filepath='model.hdf5', monitor='loss', verbose=1, save_best_only=True)]
history = model.fit(trainInput, trainOutput, callbacks=callbacks, epochs=20, batch_size=len(trainInput))
#Evaluate best training epoch's loss vs best model's loss
bestEpochLoss = numpy.min(history.history['loss'])
bestModel = keras.models.load_model('model.hdf5')
bestModelLoss = bestModel.evaluate(trainInput, trainOutput)
print('Best training epoch\'s loss: ' + str(bestEpochLoss))
print('Best model\'s loss: ' + str(bestModelLoss))
解决方案
The reason for that expectation is that in that case the model's weights are only updated once at the end of each epoch which means that the model does not change while an epoch's loss is being calculated.
Usually this is not true. Weights are updated depending on which variant of gradient descent is used. In many cases this is batch gradient descent, so you will have weight updates every batch.
推荐阅读
- c# - ASP.NET WPF - 带有圆角半径的图像溢出边框
- node.js - 尝试通过服务器端渲染渲染页面时 jsdom 中的内存泄漏
- docker - docker 与 freebsd 10 兼容吗?
- javascript - 当字段为空并使用 Ajax 发布时,checkValidity() 不显示任何 html5 错误通知
- typescript - 对象属性路径的 TypeScript 类型定义
- gitlab - 将源代码和历史从一个 GitLab 项目复制到另一个
- flutter - 如何使bottomNavigationBar缺口透明
- python - 带有 Tkintertable 的 Python 中包含一些空行的 CSV 文件
- r - 在向量中写入与矩阵中的行一样多的 1
- php - Netatmo Api - 如果温度 X 则需要 Hlep,否则