python - 为什么在 Keras 中 Adam.iterations 总是设置为 0?
问题描述
我目前正在尝试通过 keras/tensorflow 构建神经网络并解决一些示例问题。目前,我尝试了解如何通过 model.save()/.load() 正确保存和加载当前模型。我希望,如果一切设置正确,加载预训练模型并继续训练不应该破坏我之前的准确性,而只是从我离开的地方继续。
然而,事实并非如此。加载模型后,我的准确度开始大幅波动,需要一段时间才能真正恢复到之前的准确度:
第一次运行
继续运行
在挖掘了各种可能的解释(它们都不适用于我的发现)之后,我想我找到了原因:
我使用 tf.keras.optimizers.Adam 进行权重优化并检查了它的初始化程序
def __init__(self, [...], **kwargs):
super(Adam, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
[...]
def get_config(self):
config = {
'lr': float(K.get_value(self.lr)),
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'decay': float(K.get_value(self.decay)),
'epsilon': self.epsilon,
'amsgrad': self.amsgrad
}
似乎“迭代”计数器总是重置为 0,并且当整个模型保存为它不是配置字典的一部分时,它的当前值既不存储也不加载。这似乎与 model.save 保存“优化器的状态,允许在您停止的地方恢复训练”的说法相矛盾。(https://keras.io/getting-started/faq/)。由于迭代计数器是控制 Adam 算法中学习率的指数“辍学”的计数器
1. / (1. + self.decay * math_ops.cast(self.iterations,
K.dtype(self.decay))))
我的模型将始终以初始“大”学习率重新启动,即使我将 model.fit() 中的“initial_epoch”参数设置为保存模型的实际纪元数(参见上面上传的图像)。
所以我的问题是:
- 这是预期的行为吗?
- 如果是这样,这与 keras 常见问题解答中引用的声明是否一致,即 model.save() “在您停止的地方恢复训练”?
有没有办法在不编写我自己的优化器的情况下实际保存和恢复包括迭代计数器在内的 Adam 优化器(我已经发现这是一个可能的解决方案,但我想知道是否真的没有更简单的方法)
编辑 我找到了原因/解决方案:我在 load_model 之后调用了 model.compile ,这会在保持权重的同时重置优化器(另请参阅是否 model.compile() 初始化 Keras 中的所有权重和偏差(tensorflow 后端)?)
解决方案
该iterations
值已恢复,如下面的代码片段所示。
model.save('dense_adam_keras.h5')
mdl = load_model('dense_adam_keras.h5')
print('iterations is ', K.get_session().run(mdl.optimizer.iterations))
iterations is 46
当调用 ' load_model
' 时,deserialize
会调用方法来创建优化器对象,然后调用set_weights
方法从保存的权重中恢复优化器状态。
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/optimizers.py
推荐阅读
- flutter - 高效地在 ListView 中搜索
- lambda - 标记“::”上的语法错误,groupingBy 时的 AssignmentOperator 无效
- javascript - 在页面滚动时突出显示当前导航
- angular - Angular:在一个组件中调用函数,在另一个组件上调用事件
- macos - 沙盒应用程序中的 AppleScript `activate` 不会将窗口带到前台
- javascript - 使用 laravel、vue.js 和 axios 删除猫头鹰幻灯片
- python - 如何在python中获取列表的多个元素的索引?
- jasper-reports - 使用 Jaspersoft Studio 设计的报表在报表服务器上显示空白内容
- javascript - Nodemailer 在实时服务器上失败
- swift - 将 UIPanGesture 与 SpriteKit 一起使用