首页 > 解决方案 > 了解 Pytorch 过滤器功能

问题描述

我正在浏览 PyTorch 框架的文档,发现很多实例都给变量分配了一个函数,但是当它调用函数时,参数会发生变化。不确定这是如何工作的,任何指针都会有所帮助。

我所理解的——

def func1(word):
    print("hello", word)
var1 = func1

现在在这种情况下,var1("world")将打印字符串hello world

但我不明白的是 PyTorch 的一些行,例如:

def __init__(self, input_size, num_classes): 
    super(NN,self).__init__()
    self.fc1 = nn.Linear(input_size, 50)  
    self.fc2 = nn.Linear(50, num_classes)  
def forward(self,x):
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return x

我们怎么知道只有 1 个参数应该传递给self.fc2. 它似乎与nn.Linear nn.Linear中定义的参数数无关 如果是这样,是否有任何关于退回的文件?

我确实找到了 nn 模块中每个函数的用法,但是有没有什么可以详细说明它是如何工作的?

标签: pythonpytorch

解决方案


nn.Linear不是一个函数(也不是所有其他层,如卷积层、batchnorms ......),而是一个函子,这意味着它是一个实现__call__方法/运算符的类,当您编写类似self.fc2(x).

__call__运算符在nn.Module基类中实现,它是对另一个方法的调用,该方法_call_impl本身(基本上)调用该forward方法。因此,多亏了继承魔法,当你让一个类派生自 时nn.Module,你只需要实现这个forward方法。

这个方法的签名有点取决于你,但在大多数情况下,它会接受一个张量作为输入并返回另一个张量。

总之 :

# calls the constructor of nn.Linear. self.fc1 is now a functor
self.fc1 = nn.Linear(20, 10)
# calls the fc1 functor on an input 
y = self.fc1(torch.randn(2, 10))
# which is basically doing
y = self.fc1.forward(torch.randn(2, 10))

推荐阅读