nlp - 如何在pytorch中获得双向2层GRU的最终隐藏状态
问题描述
我正在努力理解如何获取隐藏层并将它们连接起来。
我以下面的代码为例:
class classifier(nn.Module):
#define all the layers used in model
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers,
bidirectional, dropout):
#Constructor
super().__init__()
self.batch = BATCH_SIZE
self.hidden = hidden_dim
self.layers = n_layers
if(bidirectional):
self.directions = 2
else:
self.directions = 1
#embedding layer
self.embedding = nn.Embedding(vocab_size, embedding_dim)
#lstm layer
self.gru = nn.GRU(embedding_dim,
hidden_dim,
num_layers=n_layers,
bidirectional=bidirectional,
dropout=dropout,
batch_first=True)
#dense layer
self.fc = nn.Linear(hidden_dim * 2, output_dim)
#activation function
self.act = nn.Sigmoid()
def forward(self, text, text_lengths):
#text = [batch size,sent_length]
embedded = self.embedding(text)
#embedded = [batch size, sent_len, emb dim]
#packed sequence
packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths,batch_first=True)
packed_output, (hidden, cell) = self.lstm(packed_embedded)
#hidden = [batch size, num layers * num directions,hid dim]
#cell = [batch size, num layers * num directions,hid dim]
#concat the final forward and backward hidden state
hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)
#hidden = [batch size, hid dim * num directions]
dense_outputs=self.fc(hidden)
#Final activation function
outputs=self.act(dense_outputs)
return outputs
hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)
我没有得到它的行。
据我了解,我这样做是行不通的。
hidden2 = hidden.view(batch_size,self.layers,self.directions,self.hidden)
hidden2 = torch.cat((hidden2[:,:,0,:],hidden2[:,:,1,:]),dim=1)
dense_outputs=self.fc(hidden2)
有人可以解释一下。我浏览了 PyTorch 文档,但没有得到。
解决方案
双向 GRU 的隐藏输出的 shape[0] 为 2。您应该在 dim=1 上连接两个隐藏输出:
hid_enc = torch.cat([hid_enc[0,:, :], hid_enc[1,:,:]], dim=1).unsqueeze(0)
作为使用 -1 和 -2 作为索引的解释,正如您在 python 列表中所知道的,索引 -1 中的对象是列表的最后一个对象(我们的张量列表中的第二个对象),索引 -2 指的是最后一个对象之前的对象(在我们的例子中是第一个对象)。所以你没看懂的代码就相当于我回答中的代码
推荐阅读
- javascript - ReactJS:如何从 useState 钩子中导出变量?
- asp.net-mvc - Ñ 字符在 HTML SELECT 中错误地显示为 Ñ
- python - CVXPY 对转子平衡问题的约束
- c - gdb 中的分段错误是否显示物理地址或虚拟地址?
- sql - 在条件 ORACLE SQL 中使用 max
- firebase - 如何从 Web 应用程序发布事件和指标以便于分析它们?
- amazon-web-services - 我可以有条件地在 AWS Appsync 解析器中调用 lambda 函数吗?
- javascript - 设置超时后无法访问 Javascript 函数
- python - 为什么 super() 继承了“错误”的类?
- python - 函数内的 Python 全局变量