首页 > 解决方案 > 了解何时在 Pytorch 中使用 python 列表

问题描述

基本上正如这个线程在这里讨论的那样,你不能使用 python list 来包装你的子模块(例如你的层);否则,Pytorch 不会更新列表中子模块的参数。相反,您应该使用nn.ModuleList包装子模块以确保它们的参数将被更新。现在我还看到了如下代码,作者使用 python list 计算损失,然后进行loss.backward()更新(在强化算法中)。这是代码:

 policy_loss = []
    for log_prob in self.controller.log_probability_slected_action_list:
        policy_loss.append(- log_prob * (average_reward - b))
    self.optimizer.zero_grad()
    final_policy_loss = (torch.cat(policy_loss).sum()) * gamma
    final_policy_loss.backward()
    self.optimizer.step()

为什么使用这种格式的列表可以更新模块的参数,但第一种情况不起作用?我现在很困惑。如果我更改前面的代码policy_loss = nn.ModuleList([]),它会抛出一个异常,说张量浮点不是子模块。

标签: deep-learningpytorchbackpropagation

解决方案


您误解了Modules 是什么。AModule存储参数并定义前向传递的实现。

您可以使用张量和参数执行任意计算,从而产生其他新张量。Modules不需要知道那些张量。您还可以在 Python 列表中存储张量列表。调用backward时,它需要在标量张量上,因此是串联的总和。这些张量是损失而不是参数,因此它们不应该是 a 的属性,Module也不应该包含在 a 中ModuleList


推荐阅读