python - 了解 Pytorch 过滤器功能
问题描述
我正在浏览 PyTorch 框架的文档,发现很多实例都给变量分配了一个函数,但是当它调用函数时,参数会发生变化。不确定这是如何工作的,任何指针都会有所帮助。
我所理解的——
def func1(word):
print("hello", word)
var1 = func1
现在在这种情况下,var1("world")
将打印字符串hello world
。
但我不明白的是 PyTorch 的一些行,例如:
def __init__(self, input_size, num_classes):
super(NN,self).__init__()
self.fc1 = nn.Linear(input_size, 50)
self.fc2 = nn.Linear(50, num_classes)
def forward(self,x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
我们怎么知道只有 1 个参数应该传递给self.fc2
. 它似乎与nn.Linear
nn.Linear中定义的参数数无关 如果是这样,是否有任何关于退回的文件?
我确实找到了 nn 模块中每个函数的用法,但是有没有什么可以详细说明它是如何工作的?
解决方案
nn.Linear
不是一个函数(也不是所有其他层,如卷积层、batchnorms ......),而是一个函子,这意味着它是一个实现__call__
方法/运算符的类,当您编写类似self.fc2(x)
.
该__call__
运算符在nn.Module基类中实现,它是对另一个方法的调用,该方法_call_impl
本身(基本上)调用该forward
方法。因此,多亏了继承魔法,当你让一个类派生自 时nn.Module
,你只需要实现这个forward
方法。
这个方法的签名有点取决于你,但在大多数情况下,它会接受一个张量作为输入并返回另一个张量。
总之 :
# calls the constructor of nn.Linear. self.fc1 is now a functor
self.fc1 = nn.Linear(20, 10)
# calls the fc1 functor on an input
y = self.fc1(torch.randn(2, 10))
# which is basically doing
y = self.fc1.forward(torch.randn(2, 10))
推荐阅读
- c - scanf 的第二个值没有改变
- html - 如何使用网站开放库将信息存储到数据库中
- reactjs - 使用表单中的选择值作为 API 参数来获取响应 - Reactjs
- yaml - 如何将地图添加到 ytt 中的地图数组中?
- openvpn - Pritunl 卡在“生成设置服务器 ssl 证书”
- python - Python:如何将可变长度前导零添加到二进制字符串?
- html - HTML5
标签,可以有属性吗? - javascript - 如何随机化 10 秒、20 秒、30 秒内一致显示的 3 个 div 的顺序
- apache-spark - 如何在字符串单词和数字的 RDD 中将数字字符串转换为 int?
- php - JSON 解析错误 - 来自 PHP json_encode 的 Javascript