首页 > 解决方案 > Module.parameters() 如何找到参数?

问题描述

我注意到,每当您创建一个新的网络扩展torch.nn.Module时,您都可以立即调用net.parameters()以查找与反向传播相关的参数。

import torch

class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x)

net = MyNet()
print(list(net.parameters()))

但后来我想知道,这怎么可能?我只是将此Linear图层对象分配给一个成员变量,但它没有记录在其他任何地方(或者是吗?)。必须能够以某种MyNet方式跟踪使用的参数,但如何?

标签: python-3.xpytorch

解决方案


真的很简单,只需通过元编程检查属性并检查它们的类型

class Example():
    def __init__(self):
        self.special_thing = nn.Parameter(torch.rand(2))
        self.something_else = "ok"

    def get_parameters(self):
        for key, value in self.__dict__.items():
            if type(value) == nn.Parameter:
                print(key, "is a parameter!")


e = Example()
e.get_parameters()
# => special_thing is a parameter!

推荐阅读