python - optimizer.param_groups[0]['lr'] 和 scheduler.get_lr()[0] 之间的区别
问题描述
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.models import AlexNet
import matplotlib.pyplot as plt
model = AlexNet(num_classes=2)
optimizer = optim.SGD(params=model.parameters(), lr=0.05)
scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
plt.figure()
x = list(range(100))
y = []
for epoch in range(100):
scheduler.step()
lr = scheduler.get_lr()
y.append(optimizer.param_groups[0]['lr'])
#y.append(scheduler.get_lr()[0])
plt.plot(x, y)
plt.savefig("a.jpg")
如果我使用:
y.append(optimizer.param_groups[0]['lr'])
如果我使用:
y.append(scheduler.get_lr()[0])
为什么 optimizer.param_groups[0]['lr'] 和 scheduler.get_lr()[0] 不同?我的网络的真实学习率是多少?
我的 pytorch 版本是 1.9.1
我检查了 StepLR 的源代码:
class StepLR(_LRScheduler):
"""Decays the learning rate of each parameter group by gamma every
step_size epochs. Notice that such decay can happen simultaneously with
other changes to the learning rate from outside this scheduler. When
last_epoch=-1, sets initial lr as lr.
Args:
optimizer (Optimizer): Wrapped optimizer.
step_size (int): Period of learning rate decay.
gamma (float): Multiplicative factor of learning rate decay.
Default: 0.1.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
Example:
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 60
>>> # lr = 0.0005 if 60 <= epoch < 90
>>> # ...
>>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
"""
def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
self.step_size = step_size
self.gamma = gamma
super(StepLR, self).__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.", UserWarning)
if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
return [group['lr'] for group in self.optimizer.param_groups]
return [group['lr'] * self.gamma
for group in self.optimizer.param_groups]
def _get_closed_form_lr(self):
return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
for base_lr in self.base_lrs]
为什么它是 ruturns group['lr'] * self.gamma?似乎 optimizer.param_groups[0]['lr'] 和 scheduler.get_lr()[0] 在以前的版本中是相同的
解决方案
推荐阅读
- xslt - 根据 XSLT 中的节点值生成新组
- django - 当 django 项目加载到 heroku 网站时,图片 url 不显示我上传的图片
- reactjs - 将内容放入饼图中,并在 chartjs 中进行设计
- javascript - 如何使轮播自动更改而不是在 ReactJs 上使用循环单击 onclick
- django - 如何在 Django REST 框架中加快发送电子邮件的速度?
- scala - 如何在 sqlite3 中使用 quill?
- go - 在 Go 中使用 Uber-Zap 记录器将指定的日志发送到 Kafka 接收器
- java - 通过 JDBC 语句执行 DDL“更改表学生添加约束 UK_fe0i52si7ybu0wjedj6motiim 唯一(电子邮件)”时出错
- javascript - 请解释这个 setTimeOut() 函数的输出?
- python - 如何在 Python 中绘制带有异常值和四分位数的箱线图