python - LSTM 隐藏状态维度存在错误:RuntimeError: Expected hidden[0] size (4, 1, 256), got (1, 256)
问题描述
我正在 PyTorch 中试验 seq2seq_tutorial。编码器的 lstm 隐藏状态大小似乎存在尺寸错误。
使用bidirectional=True
和num_layers = 2
,隐藏状态的形状应该是(num_layers*2, batch_size, hidden_size)
。
但是,出现错误并显示以下消息:
RuntimeError: Expected hidden[0] size (4, 1, 256), got (1, 256)
首先,我尝试重塑隐藏状态以使用不同的形状初始化隐藏状态,但似乎没有任何效果。
这是我的代码的 train 方法:
def train(self, input, target, encoder, decoder, encoder_optim, decoder_optim, criterion):
enc_optimizer = encoder_optim
dec_optimizer = decoder_optim
enc_optimizer.zero_grad()
dec_optimizer.zero_grad()
pair = (input, target)
input_len = input.size(0)
target_len = target.size(0)
enc_output_tensor = torch.zeros(self.opt['max_seq_len'], encoder.hidden_size, device=device)
enc_hidden = encoder.cuda().initHidden(device)
for word_idx in range(input_len):
print('Input:', input[word_idx], '\nHidden shape:', enc_hidden.size())
enc_output, enc_hidden = encoder(input[word_idx], enc_hidden)
enc_output_tensor[word_idx] = enc_output[0,0]
这是我的代码的编码器方法:
class EncoderBRNN(nn.Module):
# A bidirectional rnn based encoder
def __init__(self, input_size, hidden_size, emb_size, batch_size=1, num_layers=2, bidir=True):
super(EncoderBRNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.batch_size = batch_size
self.embedding_dim = emb_size
self.num_layers = num_layers
self.bidir = bidir
self.embedding_layer = nn.Embedding(self.input_size, self.embedding_dim)
self.enc_layer = nn.LSTM(self.embedding_dim, self.hidden_size, num_layers=self.num_layers, bidirectional=self.bidir)
def forward(self, input, hidden):
embed = self.embedding_layer(input).view(1, 1, -1)
output, hidden = self.enc_layer(embed, hidden)
return output, hidden
def initHidden(self, device):
if self.bidir:
num_stacks = self.num_layers * 2
else:
num_stacks = self.num_layers
return torch.zeros(num_stacks, self.batch_size, self.hidden_size, device=device)
解决方案
我知道这是不久前有人问过的,但我想我在这个火炬讨论中找到了答案。相关资料:
LSTM 采用隐藏状态的元组: self.rnn(x, (h_0, c_0)) 看起来你还没有在第二个隐藏状态下发送?
您还可以在LSTM的文档中看到这一点
推荐阅读
- go - 无法从 GO 中的其他文件导入包
- mule - 想知道 mule 中的重新传递策略配置对文件端点连接器的作用
- objective-c - 背景 蓝色 按钮 图像 obj c
- c - 函数在 C 中返回垃圾值但不是全部
- javascript - 如何选择html单元格并通过像日历一样拖动来突出显示
- angular - ADAL .Net Core MVC 重定向到 Angular 应用
- visual-studio-code - 如何在 VS Code 中同时查找文件夹和文件?
- octave - 将 MATLAB .m 脚本转换为 Octave 兼容语法时的维度问题
- jquery - 刷新网页时如何保留所需文件
- firebase-authentication - 后退按钮直接到登录页面而不是电话(主页)