deep-learning - 了解何时在 Pytorch 中使用 python 列表
问题描述
基本上正如这个线程在这里讨论的那样,你不能使用 python list 来包装你的子模块(例如你的层);否则,Pytorch 不会更新列表中子模块的参数。相反,您应该使用nn.ModuleList
包装子模块以确保它们的参数将被更新。现在我还看到了如下代码,作者使用 python list 计算损失,然后进行loss.backward()
更新(在强化算法中)。这是代码:
policy_loss = []
for log_prob in self.controller.log_probability_slected_action_list:
policy_loss.append(- log_prob * (average_reward - b))
self.optimizer.zero_grad()
final_policy_loss = (torch.cat(policy_loss).sum()) * gamma
final_policy_loss.backward()
self.optimizer.step()
为什么使用这种格式的列表可以更新模块的参数,但第一种情况不起作用?我现在很困惑。如果我更改前面的代码policy_loss = nn.ModuleList([])
,它会抛出一个异常,说张量浮点不是子模块。
解决方案
您误解了Module
s 是什么。AModule
存储参数并定义前向传递的实现。
您可以使用张量和参数执行任意计算,从而产生其他新张量。Modules
不需要知道那些张量。您还可以在 Python 列表中存储张量列表。调用backward
时,它需要在标量张量上,因此是串联的总和。这些张量是损失而不是参数,因此它们不应该是 a 的属性,Module
也不应该包含在 a 中ModuleList
。
推荐阅读
- nlp - Spacy:词向量使用什么算法?
- c# - Unity C# 从同一个 GameObject 上的另一个脚本中获取变量
- spring-boot - 如何确保只有一个消费者实际使用已发布的消息?
- matlab - 在 MATLAB 中对指定轴进行插值
- c# - InstanceContextMode = InstanceContextMode.PerSession 不工作
- http - Ionic http 会话在大约 5 秒后丢失。如何在 ionic http post 上启用 KeepAlive?
- php - 具有分组和跨列比较的 SQL (PHP)
- c++ - 国际象棋c ++控制台游戏,区分白色和黑色
- javascript - 孤立世界中的访问节点和电子模块
- html - 跳转到 html/css 中的某个位置