python - Keras 自定义优化器批量更改参数
问题描述
我想自定义我自己的优化器,它将在 keras 中每批结束时改变学习率。首先,我构建了一个自定义回调:
class custom_callback(Callback):
def __init__(self,lr):
super(op_callback, self).__init__()
self.lr=lr
def on_batch_end(self,batch,logs={}):
sgd = SGD(lr=batch*self.lr)
self.model.compile(optimizer=sgd,loss='categorical_crossentropy',metrics=['accuracy'])
然后,我从doc复制 SGD 优化器代码。因为我想确保学习率发生变化,所以我在get_update
函数中打印学习率。
def get_updates(self, loss, params):
print(self.lr)
...
但它只打印一次学习率。我发现该get_update
函数只会在构建计算图的开始时被调用。但是我仍然不明白为什么即使我重新初始化 SGD 实例它也不会打印任何东西。如何在优化器中更改批次结束时的参数?提前致谢。
解决方案
查看源代码LearningRateScheduler
似乎是实现您想要的最小方法如下(它没有检查get_update
调用频率,我什至不确定它是否应该在每个批次上执行,无论如何这个回调肯定确实调整了学习率):
from keras import backend as K
from keras.callbacks import Callback
class BatchLearningRateScheduler(Callback):
def __init__(self, lr):
super().__init__()
self.lr = lr
def on_batch_end(self, batch, logs=None):
lr = batch * self.lr
K.set_value(self.model.optimizer.lr, lr)
推荐阅读
- python - TensorFlow:如何将张量的行与具有相同第一个元素的张量的第二个元素相加?
- python - 如何使用 Electron 在后端使用 Python 进行图像处理应用程序?
- python - python如何通过特定的n个字符从txt文件中读取值
- flutter - 颤振中枚举的替代方案
- c# - 访问数据库连接字符串
- java - Android Studio 问题:连接失败:ECONNREFUSED(连接被拒绝)
- html - 内容溢出时页眉和页脚占据页面的全宽
- excel - 使用 VBA 代码评估和比较单元格值
- reactjs - 将数据从一个组件传递到另一个组件 - React & Redux
- gnuplot - 尝试从 2 个数据文件中绘制 3d 形状时,gnuplot“未完全指定先前的参数函数”