python - Pytorch NN 和类之间的通信
问题描述
我是 python 和 pytorch 的新手,但在理解它的工作原理时遇到了问题。
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Net(nn.Module):
def __init__(self):
..
def forward(self, x):
..
return x
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
所以,这就是代码,我从图中的代码中画出我所理解的。我有一些问题:
A) 为什么我不能直接在代码中使用 nn.CrossEntropy 而不是 'criterion'?如果我将它分配给一个变量,它有什么区别?我收到此错误:具有多个值的张量的布尔值不明确
B) 为什么当 Net 类获得一个对象 (nn) 时(我假设使用 'as' 时会创建一个对象),然后 Net 类可以简单地向后使用?它应该是 nn 的一部分,而不是 Net。你能帮我理解一下吗?
C) 虽然 optim 是一个不同的对象,但由 optim 优化的参数如何影响 nn?我不明白他们如何传递变量并相互更新?
解决方案
A)通过在一个位置将其设置为变量,它有助于更容易地在一个位置更改损失函数,而不是随着代码的大小和复杂性的增加而在许多地方键入 nn.MSELoss。基本上不太可能出错。
至于错误,需要更多信息来回答该布尔错误。在哪一行输入什么等。信息太少,无法提供帮助。
B) Net(nn.Module) 从 nn.Module 继承,它将向后添加到您添加到类中的所有操作。有关更多信息,请参阅文档。
C)“网”是一个对象。net.parameters() 是一个迭代器,它迭代网络对象中的所有参数。所以它是通过引用传递而不是通过值传递参数。
推荐阅读
- reactjs - 如何将material-ui抽屉分为header、main和Footer部分?
- ios - 在通知中点击并且用户未登录时转到特定的视图控制器
- scala - 如何使用类型别名来定义类型构造函数
- python - 检测文本中的小写首字母缩写词
- javascript - 如何使用小书签抓取部分 URL
- node.js - UnhandledPromiseRejectionWarning: MongoNetworkError: failed to connect to server on first connect with Heroku server
- sql-server - “-”附近的语法不正确
- vue.js - Webpack 没有插入 vue 构建
- python-3.x - 如何找到客户每次访问之间的天数差异
- sql - 如何在不更新表格的情况下更改字段?