python - 加载模型后更改优化器或 lr 会产生奇怪的结果
问题描述
我正在使用带有 Tensorflow 后端的最新 Keras(Python 3.6)
我正在加载一个模型,当我上次训练它时,它的训练准确率约为 86%。
我使用的原始优化器是:
r_optimizer = optimizer=Adam(lr=0.0001, decay = .02)
model.compile(optimizer= r_optimizer,
loss='categorical_crossentropy', metrics = ['accuracy'])
如果我加载模型并继续训练而不重新编译,我的准确率将保持在 86% 左右(即使在 10 个左右的 epoch 之后)。所以我想尝试改变学习率或优化器。
如果我重新编译模型并尝试更改学习率或优化器,如下所示:
new_optimizer = optimizer=Adam(lr=0.001, decay = .02)
或者这个:
sgd = optimizers.SGD(lr= .0001)
然后编译:
model.compile(optimizer= new_optimizer ,
loss='categorical_crossentropy', metrics = ['accuracy'])
model.fit ....
准确率将重置为 15% - 20% 左右,而不是从 86% 左右开始,我的损失会高得多。即使我使用较小的学习率并重新编译,我仍然会从非常低的准确度开始。通过浏览互联网,似乎一些优化器(如 ADAM 或 RMSPROP)在重新编译后重置权重时遇到问题(目前找不到链接)
所以我做了一些挖掘并尝试在不重新编译的情况下重置我的优化器,如下所示:
model = load_model(load_path)
sgd = optimizers.SGD(lr=1.0) # very high for testing
model.optimizer = sgd #change optimizer
#fit for training
history =model.fit_generator(
train_gen,
steps_per_epoch = r_steps_per_epoch,
epochs = r_epochs,
validation_data=valid_gen,
validation_steps= np.ceil(len(valid_gen.filenames)/r_batch_size),
callbacks = callbacks,
shuffle= True,
verbose = 1)
然而,这些变化似乎并没有反映在我的训练中。尽管lr
大幅提高,但我仍然以同样的损失在 86% 左右挣扎。在每个时期,我看到的损失或准确性移动非常少。我预计损失会更加波动。这让我相信我在优化器和 lr 方面的改变并没有被模型实现。
知道我做错了什么吗?
解决方案
这是参考您在此处写的内容的部分答案:
通过浏览互联网,似乎一些优化器(如 ADAM 或 RMSPROP)在重新编译后重置权重时遇到问题(目前找不到链接)
自适应优化器(例如 ADAM RMSPROP、ADAGRAD、ADADELTA 以及它们的任何变体)依赖于先前的更新步骤来改进对模型权重的任何当前调整的方向和幅度。
正因为如此,他们采取的前几个步骤往往相对“糟糕”,因为他们用之前步骤中的信息“校准自己”。
当用于随机初始化时,这不是问题,但当用于预训练模型时,这几个初始步骤会严重降低模型,以至于几乎所有预训练的工作都会丢失。
更糟糕的是,现在训练不是从像 Xavier 初始化那样精心选择的随机初始化开始,而是从某个次优的起点开始,这可能会阻止模型收敛到如果它开始就会达到的局部最优来自一个好的随机初始化。
不幸的是,我不确定如何避免这种情况...也许使用一个优化器进行预训练->保存权重->替换优化器->恢复权重->训练几个时期,并希望新的自适应优化器学习“有用的历史” - >而不是从预训练模型的保存权重中恢复权重并且无需重新编译再次开始训练,现在使用更好的优化器“历史”。
请让我们知道这是否有效。
推荐阅读
- jupyter-notebook - 当 Jupyter Hub 中的 Python 文件很大时,无法添加目录扩展
- azure-data-factory - 在 .net 核心中使用 .NET SDK 创建 Azure 数据工厂?
- arrays - Hackerrank 中 numpy 的眼睛和身份功能。为什么测试用例失败了?
- uwp - CalendarDatePicker,TodayDate 在 UWP 中没有突出显示?
- python - 如果该行与DataFrame中的其他行有一定关系,如何删除行?
- node.js - 如何使用 node.js 和 pm2 将 verdaccio 发布到互联网?
- sql - sql中存储过程的嵌套循环不起作用
- c++ - 为什么我在提交leetcode时会报错,但在IDE中使用相同的代码?
- retrofit2 - 带有 Retrofit 2 的 kotlinx 序列化中的构建错误
- python - 我刚刚开始用 Python 制作一个基本的交互式计算器,但不知道为什么它不接受给定的输入