python - 在pytorch LSTM上循环
问题描述
我正在为 pytorch 中的机器翻译训练一个 seq2seq 模型。我想在每个时间步收集细胞状态,同时仍然具有多层和双向的灵活性,例如,您可以在 pytorch 的 LSTM 模块中找到。
为此,我有以下编码器和转发方法,其中我循环 LSTM 模块。问题是,模型训练得不是很好。循环终止后,您可以看到使用 LSTM 模块的正常方式,然后模型训练。
那么,循环不是一种有效的方法吗?
class encoder(nn.Module):
def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
super().__init__()
self.input_dim = input_dim
self.emb_dim = emb_dim
self.hid_dim = hid_dim
self.n_layers = n_layers
self.dropout = dropout
self.embedding = nn.Embedding(input_dim, emb_dim)
self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)
self.dropout = nn.Dropout(dropout)
def forward(self, src):
#src = [src sent len, batch size]
embedded = self.dropout(self.embedding(src))
#embedded = [src sent len, batch size, emb dim]
hidden_all = []
for i in range(len(embedded[:,1,1])):
outputs, hidden = self.rnn(embedded[i,:,:].unsqueeze(0))
hidden_all.append(hidden)
#outputs, hidden = self.rnn(embedded)
#outputs = [src sent len, batch size, hid dim * n directions]
#hidden = [n layers * n directions, batch size, hid dim]
#cell = [n layers * n directions, batch size, hid dim]
None
#outputs are always from the top hidden layer
return hidden
解决方案
好的,所以修复很简单,你可以在外面运行第一个时间步,得到一个隐藏的元组输入 LSTM 模块。
推荐阅读
- ruby-on-rails - 将 devise_invitable 与枚举用户角色一起使用
- c# - 努力在 C# 中使用 Newtonsoft.JSON 反序列化 JSON 字符串
- java - 将图像上传到 Firebase - Java
- python - 从 Azure 数据块连接到 Azure 表存储
- r - 从 R 的 flexdashboard 中的过滤器中选择时,在传单上添加标记
- ros - 跨机器/主机使用 ROS 自定义消息类型
- c# - 具有多个参数的 Web API Post
- javascript - 使用 javascript/jquery 中的逗号拆分将输入中的文本转换为显示为 href 链接
- javascript - SendMouseClickEvent 有时不起作用
- java - 使用 aws cognito 用户池登录 Facebook\Google