python - pytorch中的神经网络定义如何使用pyton类
问题描述
为了了解这段代码的工作原理,我编写了一个小型复制器。self.hidden 变量如何在 forward 方法中使用变量 x?
enter code class Network(nn.Module):
def __init__(self):
super().__init__()
# Inputs to hidden layer linear transformation
self.hidden = nn.Linear(784, 256)
# Output layer, 10 units - one for each digit
self.output = nn.Linear(256, 10)
# Define sigmoid activation and softmax output
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=1)
def forward(self, x):
# Pass the input tensor through each of our operations
x = self.hidden(x)
x = self.sigmoid(x)
x = self.output(x)
x = self.softmax(x)
return x
解决方案
你误解了什么self.hidden = nn.Linear(784, 256)
。你写道:
hidden
被定义为一个函数
但是这是错误的。self.hidden
是类的一个对象nn.Linear
。当您打电话时self.hidden(...)
,您不会将参数传递给nn.Linear
; 您将参数传递给__call__
(在nn.Linear
类中定义)。
如果您想了解更多详细信息,我已经扩展了它在 PyTorch 中的工作原理:请参阅此答案。
推荐阅读
- angular-material - Angular Material:奇怪的后缀功能
- elasticsearch - Elasticsearch bucket_selector 包括父聚合大小
- java - 在java中拆分字符串的问题
- powerbi - Power BI - TopoJSON 的形状图渲染问题
- excel - 运行时错误“91”:对象变量或未设置块
- swift - SwiftUI:如何在内容视图中显示函数的输出
- python - Webscraping - 我需要一些帮助来理解如何区分页面上的项目 BS4,请求
- c++ - 在 Windows Server 2019 上,如何让两个 MFC 应用程序使用窗口句柄和 GetDC() 进行通信?
- reporting-services - SSRS 提示输入凭据 HTTPS 但不是 HTTP
- reactjs - Reactjs:使用 libarchivejs 处理 7z 文件:Archive.open() 被冻结并且没有给出错误或结果