python - LSTM 中的 HiddenState 声明问题
问题描述
我想使用一个使用 PyTorch 框架的 LSTM 神经网络来预测多变量 (18) 时间序列(销售)。我有 16 个特性(每个特性的滞后变量、价格、其他因素)——总共 288 个。类声明如下所示:
class LSTMModel(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(LSTMModel, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.output_size, batch_first=True)
self.fc = torch.nn.Linear(self.hidden_size, self.output_size)
def forward(self, x):
h0 = torch.zeros(1, x.size(0), self.hidden_size).requires_grad_()
c0 = torch.zeros(1, x.size(0), self.hidden_size).requires_grad_()
out, (hn, cn) = self.lstm(x, (h0.detach_(), c0.detach_()))
out = self.fc(out[:, -1, :])
return out
Input_size 这里 288, hidden_size 96, output_size 18。然后我初始化我的模型:
hidden_size = 96
model = LSTMModel(X_all.shape[1], hidden_size, y_all.shape[1])
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)
然后我正在尝试训练我的模型(batch_size 为 100):
model.train()
epochs = 500
train_losses = []
for epoch in range(1, epochs):
batch_losses = []
iter = 0
for x_batch, y_batch in train_loader:
x_batch = x_batch.view([x_batch.shape[0], -1, X_all.shape[1]])
optimizer.zero_grad()
y_pred = model(x_batch)
y_pred[y_pred < 0] = 0
loss = criterion(y_pred.squeeze(), y_train[iter*batch_size:min(iter*batch_size+batch_size, X_all.shape[0])])
batch_losses.append(loss.item())
loss.backward()
optimizer.step()
iter +=1
training_loss = np.mean(batch_losses)
print('Epoch {}: train loss: {}'.format(epoch, training_loss.item()))
train_losses.append(training_loss)
y_pred_train_n = y_pred.squeeze().detach().numpy()
并在以下行中有一个奇怪的错误:
out, (hn, cn) = self.lstm(x, (h0.detach_(), c0.detach_()))
“RuntimeError:预期隐藏 [0] 大小 (18, 826, 96),得到 [1, 826, 96]”
我究竟做错了什么?从 pytorch LSTM 的文档中:
h_0: 形状张量 (D * \text{num_layers}, N, H_{out})
我有 1 个隐藏层,NN 不是双向的,所以隐藏向量的第一个暗淡应该是 1。为什么是 18?如果我在 LSTM 类声明中更改第一个 dim(18 而不是 1),那么它可以工作。问题是什么?
解决方案
推荐阅读
- c# - 通过静态类访问 HttpContext 可以“正确”处理不同的请求
- google-chrome - 定义 Chrome IE 标签特定样式表
- php - PHP Session 不适用于某些(有时)移动用户?
- php - Laravel Cashier 10 未通过作曲家安装
- ios - 如何限制应用程序仅在 iphone 中下载而不在 iPad 中下载
- php - Octobercms - 表单上的 CSRF 保护以防止多次提交
- ios - 将 Xcode 更新到 10.4 版后 Apple Pay 无法正常工作
- java - 具有不同单元格高度的列表视图
- database - SQL * PLUS 在 glogin.sql 文件中定义 CONNECT_IDENTIFIER
- python-3.x - 在 Windows 10 上安装 fbprophet Python 时出现 Numpy 错误