python - LSTM.weight_ih_l[k] 维度与 proj_size
问题描述
根据 Pytorch LSTM 文档:-
- ~LSTM.weight_ih_l[k] – 第 k 层 (W_ii|W_if|W_ig|W_io) 的可学习输入隐藏权重,形状为 (4*hidden_size, input_size),k = 0。否则,形状为 (4 * hidden_size, num_directions * hidden_size)
我的疑问是,为什么每个k > 0
权重 的形状是(hidden_size, num_directions * hidden_size)
(hidden_size, num_directions * proj_size)
(L, N, num_directions*proj_size)
更新
事实上,文档中提到了错误的形状。很快就会修复。
解决方案
现在已修复。OP 打开了由PR #65102(提交83878e1 )修复的问题 #65053。
碰巧的是,在这种情况下,文档并未提供所有详细信息,您是对的。实际上,您可以在源代码中查看形状W_ih
为(4*hidden_size, num_directions * proj_size)
when proj_size > 0
for的源代码k > 0
:
# [...]
if mode == 'LSTM':
gate_size = 4 * hidden_size
elif mode == 'GRU':
gate_size = 3 * hidden_size
elif mode == 'RNN_TANH':
gate_size = hidden_size
elif mode == 'RNN_RELU':
gate_size = hidden_size
else:
raise ValueError("Unrecognized RNN mode: " + mode)
self._flat_weights_names = []
self._all_weights = []
for layer in range(num_layers):
for direction in range(num_directions):
real_hidden_size = proj_size if proj_size > 0 else hidden_size
layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions
w_ih = Parameter(torch.empty((gate_size, layer_input_size), **factory_kwargs))
# [...]
如您所见,w_ih
具有形状(gate_size, layer_input_size)
,其中:
gate_size
用于4 * hidden_size
LSTM,并且layer_input_size
是input_size
iflayer == 0
(layer
相当于k
文档中的),否则real_hidden_size * num_directions
对于k > 0
, 和
real_hidden_size = proj_size
如果proj_size > 0
,否则它是hidden_size
。
即:如果proj_size > 0
和layer > 0
,layer_input_size = proj_size * num_directions
和 的形状w_ih
将等于(4 * hidden_size, proj_size * num_directions
。
值得注意的是,他们在文档中确实包含以下内容:
如果
proj_size > 0
指定,将使用带有投影的 LSTM。这会以下列方式更改 LSTM 单元。首先, 的尺寸h_t
将从 更改hidden_size
为proj_size
(的尺寸W_hi
将相应更改)。其次,每一层的输出隐藏状态将乘以一个可学习的投影矩阵h_t = W_hr * h_t
:请注意,因此,LSTM 网络的输出也将具有不同的形状。
推荐阅读
- python - numpy intersect1d 意外关键字参数'return_indices'
- wso2 - 为 WSO2 Developer Studio 调色板设置自定义中介
- java - Java - 在当前时间和未来设定时间之间创建一个时间间隔数组(15 分钟)
- html - col-xl-12 不是 100% 的 div 父亲
- javascript - 失踪 ; 声明之前(Jira 插件 - oAuth)
- python - 我们如何使用在循环内部描述的变量,在python中的循环之外
- excel - 比较两个双打是否相等失败
- objective-c - Objective-C:变量不会被使用(如果)
- r - 使用 rle() 查找给定值的最大运行的索引位置
- javascript - Redux:选择器模式的另一种实现