python - 如何在 Pytorch 中打印调整学习率?
问题描述
当我在我的 PPO 算法中使用 torch.optim.Adam 和指数衰减_lr 时:
self.optimizer = torch.optim.Adam([
{'params': self.policy.actor.parameters(), 'lr': lr_actor},
{'params': self.policy.critic.parameters(), 'lr': lr_critic}
])
self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, self.GAMMA)
初始 lr=0.1,GAMMA=0.9。
然后我在我的时代动态打印 lr:
if time_step % update_timestep == 0:
ppo_agent.update()
print(f'__________start update_______________')
print(ppo_agent.optimizer.state_dict()['param_groups'][0]['lr'])
但是,这有问题,错误是:
File "D:\Anaconda\lib\site-packages\torch\distributions\beta.py", line 36, in __init__
self._dirichlet = Dirichlet(concentration1_concentration0, validate_args=validate_args)
File "D:\Anaconda\lib\site-packages\torch\distributions\dirichlet.py", line 52, in __init__
super(Dirichlet, self).__init__(batch_shape, event_shape, validate_args=validate_args)
File "D:\Anaconda\lib\site-packages\torch\distributions\distribution.py", line 53, in __init__
raise ValueError("The parameter {} has invalid values".format(param))
ValueError: The parameter concentration has invalid values
然后,如果我删除“print()”语句,它确实工作得很好!所以,这让我非常困扰。
解决方案
您可以像这样获得学习率:
self.optimizer.param_groups[0]["lr"]
推荐阅读
- python - python模块的安装方式
- reactjs - 反应原生裁剪图像
- facebook - 如何使用 Facebook Marketing Api 获取与页面相关的所有活动?
- r - group_by 函数不能与另一个 group_by 一起使用
- powercli - vNIC 禁用 CSV 输出
- azure-logic-apps - 使用 HTTP GET 请求将遥测数据发送到 Azure IoT Central 设备
- html - HTML / FLASK - 每行带有复选框的动态表
- python - 将单个对象分配给每个 Python 多处理器线程
- c++ - 如何将时间字符串(M:SS)转换为浮点数
- laravel-5 - 不从 laravel 背包 4.0 中删除图像