python-3.x - 这个 Rnn 函数的最后一行是什么意思?
问题描述
我是来问一个菜鸟问题的。
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size*2, num_classes)
def forward(self, x):
h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)
out, _ = self.rnn(x, h0) # out: tensor of shape (batch_size, seq_length, hidden_size)
out = self.fc(out[:, -1, :])
return out
这里是什么out = self.fc(out[:, -1, :])
意思?还有为什么里面有一个“_” out, _ = self.rnn(x, h0)
?
解决方案
该行out = self.fc(out[:, -1, :])
使用负索引: out 是一个张量 shape batch_size x seq_length x hidden_size
,因此 out[:, 1, :] 将返回沿第二维(或轴)的第一个元素,并out[:, -1, :]
沿第二维返回最后一个元素。它相当于out[:, seq_length-1, :]
.
下划线 inout, _ = self.rnn(x, h0)
表示self.rnn(x, h0)
返回两个输出,out 分配给第一个输出,第二个输出没有分配给任何东西,所以_
是占位符。
推荐阅读
- python - 使用 python keras 训练 CNN 时出现 AttributeError
- core-data - CoreData+CloudKit iOS13 NSPersistentStoreRemoteChangeNotification
- python - 如何将视频流从 python 传输到电子?
- google-vault-api - 创建组时出现内部服务器错误
- vuejs2 - 如何使用 vue 路由器添加表单部分 HTML?
- c++ - 通过 C++ 中的重载构造函数初始化未知类型的变量
- java - 我试图将数组存储到 txt 文件
- powershell - Jenkins Pipeline Script - 评估批处理/powershell的标准输出
- react-native - 处理生物特征认证
- scala - SBT:如何依赖同一项目中的交叉编译模块