python - 在 Keras 中使用 Adam 优化器恢复训练
问题描述
我的问题很简单,但我无法在网上找到明确的答案(到目前为止)。
在定义数量的训练时期后,我使用以下方法保存了使用 adam 优化器训练的 keras 模型的权重:
callback = tf.keras.callbacks.ModelCheckpoint(filepath=path, save_weights_only=True)
model.fit(X,y,callbacks=[callback])
当我关闭 jupyter 后恢复训练时,我可以简单地使用:
model.load_weights(path)
继续训练。
由于 Adam 依赖于 epoch 数(例如在学习率衰减的情况下),我想知道在与以前相同的条件下恢复训练的最简单方法。
按照 ibarrond 的回答,我编写了一个小的自定义回调。
optim = tf.keras.optimizers.Adam()
model.compile(optimizer=optim, loss='categorical_crossentropy',metrics=['accuracy'])
weight_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1, save_best_only=False)
class optim_callback(tf.keras.callbacks.Callback):
'''Custom callback to save optimiser state'''
def on_epoch_end(self,epoch,logs=None):
optim_state = tf.keras.optimizers.Adam.get_config(optim)
with open(optim_state_pkl,'wb') as f_out:
pickle.dump(optim_state,f_out)
model.fit(X,y,callbacks=[weight_callback,optim_callback()])
当我恢复训练时:
model.load_weights(checkpoint_path)
with open(optim_state_pkl,'rb') as f_out:
optim_state = pickle.load(f_out)
tf.keras.optimizers.Adam.from_config(optim_state)
我只想检查这是否正确。再次感谢!
附录:在进一步阅读 Adam 的默认Keras 实现和原始 Adam 论文后,我认为默认 Adam 不依赖于 epoch 数,而仅依赖于迭代次数。因此,这是不必要的。但是,对于希望跟踪其他优化器的任何人,该代码可能仍然有用。
解决方案
为了完美地捕获优化器的状态,您应该使用函数存储其配置get_config()
。此函数返回一个字典(包含选项),可以使用pickle
.
要重新启动该过程,只需d = pickle.load('my_saved_tfconf.txt')
使用配置检索字典,然后使用Keras Adam Optimizer 的功能生成您的 Adam Optimizerfrom_config(d)
。
推荐阅读
- pjsip - 如何处理 pjsip 消息?
- python - 在 IDLE 中的会话之间保存命令历史记录
- sql-server - SQL Server 中的实体框架 + 列的平均部分
- c++ - 在 C++ 中打开文件时出现分段错误
- javascript - 将 ltr 提供给 MDC 开关在 rtl MDC 抽屉内不起作用
- spring - 是在启动 Spring Boot 应用程序 2+ 时进行线程转储吗?我如何停止或禁用它?
- amazon-redshift - 如何找到授予 Redshift 权限的时间戳?
- javascript - 无论如何要在 Chrome 中请求更高的 localStorage 限制而不进行扩展(从 Javascript 或用户手动从 Chrome UI)?
- pandas - 根据 pandas 数据框的列名和值创建字典
- assembly - 汇编语言:为什么我的字符串打印在错误的地方