deep-learning - pytorch seq2seq 编码器正向方法
问题描述
我正在关注Pytorch seq2seq 教程,下面是他们如何定义编码器函数。
class EncoderRNN(nn.Module):
def __init__(self, input_size, hidden_size):
super(EncoderRNN, self).__init__()
self.hidden_size = hidden_size
self.embedding = nn.Embedding(input_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size)
def forward(self, input, hidden):
embedded = self.embedding(input).view(1, 1, -1)
output = embedded
output, hidden = self.gru(output, hidden)
return output, hidden
def initHidden(self):
return torch.zeros(1, 1, self.hidden_size, device=device)
但是,似乎forward
在训练期间从未真正调用方法。
以下是本教程中如何使用编码器转发方法:
for ei in range(input_length):
encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
encoder_outputs[ei] = encoder_output[0, 0]
它不应该是encoder.forward
而不是只是encoder
?Pytorch 中是否有一些我不知道的自动“前进”机制?
解决方案
在 PyTorch 中,您可以通过扩展和定义 forward 方法来编写自己的类,torch.nn.Module
以表达您所需的计算步骤,这些步骤充当方法中的“文书工作”(例如调用挂钩)model.__call__(...)
(这是 model(x) 将通过 python 特殊名称调用的内容)规格)。
如果您好奇,您可以查看除了在这里model(x)
调用之外的幕后操作: https ://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L462model.forward(x)
此外,您可以看到显式调用该.foward(x)
方法与仅model(x)
在此处简单使用之间的区别:https ://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L72
推荐阅读
- vbscript - VBScript:从网页下载 JSON 文件并将内容读取到变量
- ios - 如何从 AVAudioBuffer 音频样本中确定音量
- next.js - NextJs - Apollo - 第一次渲染不是 SSR
- sql - 没有 CREATE TABLE 的连接列上的 SQL JOIN
- sqlite - 当从文字中的相同总数中减去总和时,SQLite 会显示浮点数
- python - 在熊猫数据框中组合多行
- reactjs - 收到警告:列表中的每个孩子都应该有一个唯一的“关键”道具。在传播 props 和组合多个 Mui 组件时
- spring-integration - WARN JdbcChannelMessageStore 带有 id 的消息没有被删除
- python - cv2.imshow([winname], [mat]) 后的分段错误
- angular - AngularFirestore 动态查询