machine-learning - Pytorch 中 GRU 单元的隐藏和输出是否相同?
问题描述
我在概念上确实理解 LSTM 或 GRU 应该做什么(感谢这个问题What's the difference between "hidden" and "output" in PyTorch LSTM?)但是当我检查 GRU 的输出时h_n
,output
它们应该是不一样的。 ..
(Pdb) rnn_output
tensor([[[ 0.2663, 0.3429, -0.0415, ..., 0.1275, 0.0719, 0.1011],
[-0.1272, 0.3096, -0.0403, ..., 0.0589, -0.0556, -0.3039],
[ 0.1064, 0.2810, -0.1858, ..., 0.3308, 0.1150, -0.3348],
...,
[-0.0929, 0.2826, -0.0554, ..., 0.0176, -0.1552, -0.0427],
[-0.0849, 0.3395, -0.0477, ..., 0.0172, -0.1429, 0.0153],
[-0.0212, 0.1257, -0.2670, ..., -0.0432, 0.2122, -0.1797]]],
grad_fn=<StackBackward>)
(Pdb) hidden
tensor([[[ 0.1700, 0.2388, -0.4159, ..., -0.1949, 0.0692, -0.0630],
[ 0.1304, 0.0426, -0.2874, ..., 0.0882, 0.1394, -0.1899],
[-0.0071, 0.1512, -0.1558, ..., -0.1578, 0.1990, -0.2468],
...,
[ 0.0856, 0.0962, -0.0985, ..., 0.0081, 0.0906, -0.1234],
[ 0.1773, 0.2808, -0.0300, ..., -0.0415, -0.0650, -0.0010],
[ 0.2207, 0.3573, -0.2493, ..., -0.2371, 0.1349, -0.2982]],
[[ 0.2663, 0.3429, -0.0415, ..., 0.1275, 0.0719, 0.1011],
[-0.1272, 0.3096, -0.0403, ..., 0.0589, -0.0556, -0.3039],
[ 0.1064, 0.2810, -0.1858, ..., 0.3308, 0.1150, -0.3348],
...,
[-0.0929, 0.2826, -0.0554, ..., 0.0176, -0.1552, -0.0427],
[-0.0849, 0.3395, -0.0477, ..., 0.0172, -0.1429, 0.0153],
[-0.0212, 0.1257, -0.2670, ..., -0.0432, 0.2122, -0.1797]]],
grad_fn=<StackBackward>)
它们是相互转置的……为什么?
解决方案
它们并不完全相同。考虑我们有以下单向GRU 模型:
import torch.nn as nn
import torch
gru = nn.GRU(input_size = 8, hidden_size = 50, num_layers = 3, batch_first = True)
请确保仔细观察输入形状。
inp = torch.randn(1024, 112, 8)
out, hn = gru(inp)
确实,
torch.equal(out, hn)
False
帮助我理解输出与隐藏状态的最有效方法之一是查看双向循环网络的hn
位置hn.view(num_layers, num_directions, batch, hidden_size)
(num_directions = 2
另一种方式,即我们的案例)。因此,
hn_conceptual_view = hn.view(3, 1, 1024, 50)
正如文档所述(注意斜体和粗体):
h_n of shape (num_layers * num_directions, batch, hidden_size):包含 t = seq_len 的隐藏状态的张量(即最后一个时间步)
在我们的例子中,这包含时间步长的隐藏向量t = 112
,其中:
形状的输出(seq_len,batch,num_directions * hidden_size):张量包含来自GRU的最后一层的输出特征 h_t ,对于每个 t。如果一个 torch.nn.utils.rnn.PackedSequence 作为输入,输出也将是一个打包序列。对于未打包的情况,可以使用 output.view(seq_len, batch, num_directions, hidden_size) 来分离方向,向前和向后分别是方向 0 和 1。
因此,因此,可以这样做:
torch.equal(out[:, -1], hn_conceptual_view[-1, 0, :, :])
True
解释:我将所有批次的最后一个序列out[:, -1]
与最后一层隐藏向量进行比较hn[-1, 0, :, :]
对于双向GRU(需要先读取单向):
gru = nn.GRU(input_size = 8, hidden_size = 50, num_layers = 3, batch_first = True bidirectional = True)
inp = torch.randn(1024, 112, 8)
out, hn = gru(inp)
视图更改为(因为我们有两个方向):
hn_conceptual_view = hn.view(3, 2, 1024, 50)
如果您尝试确切的代码:
torch.equal(out[:, -1], hn_conceptual_view[-1, 0, :, :])
False
解释:这是因为我们甚至在比较错误的形状;
out[:, 0].shape
torch.Size([1024, 100])
hn_conceptual_view[-1, 0, :, :].shape
torch.Size([1024, 50])
请记住,对于双向网络,隐藏状态在每个时间步被连接起来,其中第一个hidden_state
大小(即)是前向网络的隐藏状态,另一个大小是后向网络(即)。前向网络的正确比较是:out[:, 0,
:50
]
hidden_state
out[:, 0,
50:
]
torch.equal(out[:, -1, :50], hn_conceptual_view[-1, 0, :, :])
True
如果您想要后向网络的隐藏状态,并且由于后向网络从时间步开始处理序列n ... 1
。您比较序列的第一个时间步,但最后一个hidden_state
大小并将hn_conceptual_view
方向更改为1
:
torch.equal(out[:, -1, :50], hn_conceptual_view[-1, 1, :, :])
True
简而言之,一般来说:
单向:
rnn_module = nn.RECURRENT_MODULE(num_layers = X, hidden_state = H, batch_first = True)
inp = torch.rand(B, S, E)
output, hn = rnn_module(inp)
hn_conceptual_view = hn.view(X, 1, B, H)
RECURRENT_MODULE
GRU 或 LSTM(在撰写本文时)在哪里,B
是批量大小、S
序列长度和E
嵌入大小。
torch.equal(output[:, S, :], hn_conceptual_view[-1, 0, :, :])
True
我们再次使用S
,因为它rnn_module
是前向的(即单向的)并且最后一个时间步长存储在序列长度S
中。
双向:
rnn_module = nn.RECURRENT_MODULE(num_layers = X, hidden_state = H, batch_first = True, bidirectional = True)
inp = torch.rand(B, S, E)
output, hn = rnn_module(inp)
hn_conceptual_view = hn.view(X, 2, B, H)
比较
torch.equal(output[:, S, :H], hn_conceptual_view[-1, 0, :, :])
True
以上是前向网络比较,我们之所以使用前向网络,是:H
因为前向将其隐藏向量存储在H
每个时间步的第一个元素中。
对于后向网络:
torch.equal(output[:, 0, H:], hn_conceptual_view[-1, 1, :, :])
True
我们将方向更改hn_conceptual_view
为1
以获取后向网络的隐藏向量。
对于我们使用的所有示例,hn_conceptual_view[-1, ...]
因为我们只对最后一层感兴趣。
推荐阅读
- python-3.x - 循环遍历字符串中的每个字符,然后用首字母缩写词替换所述字符,而不使用 replace() 或 find() 函数
- php - 为什么只向 mysql 数据库插入一列不会在 codeigniter 中显示错误?
- php - 当有人仅使用 PHP 和 HTML 在表单中输入内容时,如何创建新的表格行?
- java - 如何在 webclient 中返回 Flux.just 状态
- python - Python相当于r中的子集函数
- python - 'numpy.ndarray' 对象没有使用 reshape 的属性 'values'
- html - 离子从一页导航到另一页,不起作用
- php - 如何从数据库中列出 DISTINCT ALL 产品
- r - R 误差中的多线性回归:0(非 NA)案例。我怎样才能让它工作?
- javascript - 获取 TextField 组件的值