python - 有人可以解释这个 pytorch 神经网络代码吗?这里有两种不同的神经网络还是一种?
问题描述
class doubleNetwork(nn.Module):
def __init__(self, input_dim, output_dim):
super(doubleNetwork, self).__init__()
self.policy1 = nn.Linear(input_dim, 256)
self.policy2 = nn.Linear(256, output_dim)
self.value1 = nn.Linear(input_dim, 256)
self.value2 = nn.Linear(256, 1)
def forward(self, state):
logits = F.relu(self.policy1(state))
logits = self.policy2(logits)
value = F.relu(self.value1(state))
value = self.value2(value)
return logits, value
是
policy1
,value1
在不同的网络中还是相同的?这里有两个不同的神经网络还是一个?
这里的代码发生了什么?
解决方案
您有两个并行的网络。你可以在 forward 方法中看到它:
state
-> policy1
-> policy2
->logits
state
-> value1
-> value2
->value
policy1
、policy2
和是 4 个不同value1
且value2
独立的全连接(线性)层。该nn.Linear
方法每次调用时都会创建一个新的神经元层。
编辑以获取更多详细信息:
在您的代码中定义一个doubleNetwork
类,该__init__
方法将在您创建此类的对象时调用
所以这一行:
my_network = doubleNetwork(10,15)
调用该__init__
方法,并创建一个新的 doubleNetwork 对象。newtork 将有 4 个属性 value1、value1、policy1、policy2,它们是全连接层。
该行:
self.policy1 = nn.Linear(input_dim, 256)
创建一个新的线性对象,它是一个完全连接的层,当这条线被执行时,层的权重被初始化。
network的forward
方法定义了调用网络对象时追加的内容。例如像这样的一行:
output1, output2 = my_network(input)
forward 中编写的代码是应用于输入的函数。这里作为状态的输入被传递到一侧的两个策略层,然后传递到两个值层。然后返回两个输出。所以网络是一个输入和两个输出的分叉形式。
在这段代码中,它是一个网络,但由于两个输出仅依赖于输入并且彼此独立,我们可以将它们定义在两个独立的网络中,结果相同。例如看代码:
class SingleNetwork(nn.Module):
def __init__(self, input_dim, output_dim):
super(doubleNetwork, self).__init__()
self.layer1 = nn.Linear(input_dim, 256)
self.layer2 = nn.Linear(256, output_dim)
def forward(self, state):
output = F.relu(self.layer1(state))
output = self.layer2(output)
return output
my_network1 = singleNetwork(10,15)
my_network2 = singleNetwork(10,1)
然后:
output1 = my_network1(input)
output2 = my_network2(input)
将相当于
output1, output2 = my_network(input)
推荐阅读
- go - 如何从 *http.Request 和 *httptest.ResponseRecorder 创建 gin.Context?
- node.js - Node.js - 服务器不是构造函数() ES6
- r - 使用 checkboxGroupInput 作为数字输入的最小值,Shiny R
- mysql - 带有自然语言搜索的 MySQL LIKE 运算符
- neo4j - 如何测量 Neo4j 的索引速度
- ios - 如何在 iOS Swift 中附加二维数组?
- php - 如何通过 React js 在 php 中发布数据
- javascript - 从电子中的 html 中检索一个类
- c# - 如何使用 PDFSharp 将输入文本字段添加到 Pdf (AcroForm)
- c - 函数和字符串,检查输入字符串是否匹配