python - pytorch nn.Module 推理
问题描述
我打算学习 Pytorch。但是在这个阶段我想问一个问题,以便我可以理解我正在阅读的一些代码
当你有一个基类是的类nn.Module
时
class My_model(nn.Module)
应该如何在那里进行推理?
在我正在阅读的代码中,它说
tasks_output, other = my_model(data)
那不就是创建一个对象吗?(比如调用类构造函数)
在 pytorch 中,应该如何进行推理?
(作为参考,我说的是什么时候my_model
设置为my_model.eval()
)
编辑:我很抱歉。我犯了将类和对象声明为一个的错误。我更正了代码
解决方案
你有例如:
class My_model(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# Call construtor of Class
my_model = My_model()
区分类和对象很重要。名称的类在 Python 中以大写字母开头。
如您所见,构造函数不带数据/输入参数,只有函数 forward 有一个。
之后,对于培训,您必须需要:
- 标准谁计算带有标签的模型的误差。
- 它必须具有反向传播算法的优化器
例子:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
最后,您必须通过循环需要以下元素:
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
在这里,您有一次反向传播迭代。
如果您想考虑反向传播中的推理,您可以阅读如何使用 pytorch 创建图层以及 pytorch 如何使用签名。
张量使用 Autograph 进行反向传播。Pytorch文档示例
import torch
x = torch.ones(5) # input tensor
y = torch.zeros(3) # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()
print(w.grad)
print(b.grad)
结果给出了反向传播,其中交叉熵准则计算模型和标签的距离。张量 z 不是唯一的值矩阵,而是具有 w、b、x、y 的“记忆计算”的类。
在该层中,梯度使用前向函数进行此计算,或者在必要时使用后向函数。
最良好的问候
推荐阅读
- javascript - RadioGroup 中的动态 FormControlLabel Radio 没有得到检查
- reactjs - 将对象从子状态传递给父状态
- r - 是否可以在 for 循环中使用 aov 函数?
- react-native - React Native Android 调试构建在启动时崩溃,iOS 工作正常
- java - 我无法创建一种方法来计算出现次数
- api - Laravel 5.8 休息客户端如何在 .env 中保存 api 令牌
- spring-boot - 如何使用 Spring boot 在 OAuth2.0 中添加用于授权和身份验证的自定义逻辑?
- android - 为什么 setOnclicklistener 不能在无尽回收器视图中的特定行中工作?
- python - 如果在目标表中找不到记录(与源具有相同的索引号),则比较两个数据帧(源与目标)并留空行
- codenameone - 允许您指定 iOS 构建的构建提示