首页 > 解决方案 > 了解 PyTorch Linear 的工作原理

问题描述

我正在考虑文档中的示例代码:

import torch
from torch import nn
#
m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
print(output.size())

输出是:

torch.Size([128, 30])

线性的构造函数是:

def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:

这与创建实例的方式一致,即:

m = nn.Linear(20, 30)

但是,当使用 m 时,它会收到一个张量

output = m(input)

作为输入。我不懂为什么。这个张量在源代码中定义在哪里?

标签: pytorch

解决方案


当您这样做时m(input),将调用__call__什么是__call__)方法,该方法在内部调用forward方法并执行其他操作。这个逻辑写在基类中:nn.Module. 为简单起见,现在假设做m(input)相当于m.forward(input).

输入是forward什么?张量。

def forward(self, input: Tensor) -> Tensor

推荐阅读