python - RuntimeError: Expected hidden size (2, 24, 50), got (2, 30, 50)
问题描述
我正在尝试建立一个模型来学习数据集中某些句子的分配分数(实数)。为此,我使用 RNN(在 PyTorch 中)。我定义了一个模型:
class RNNModel1(nn.Module):
def forward(self, input ,hidden_0):
embedded = self.embedding(input)
output, hidden = self.rnn(embedded, hidden_0)
output=self.linear(hidden)
return output , hidden
训练函数如下:
def train(model,optimizer,criterion,BATCH_SIZE,train_loader,clip):
model.train(True)
total_loss = 0
hidden = model._init_hidden(BATCH_SIZE)
for i, (batch_of_data, batch_of_labels) in enumerate(train_loader, 1):
hidden=hidden.detach()
model.zero_grad()
output,hidden= model(batch_of_data,hidden)
loss = criterion(output, sorted_batch_target_scores)
total_loss += loss.item()
loss.backward()
torch.nn.utils.clip_grad_norm(model.parameters(), clip)
optimizer.step()
return total_loss/len(train_loader.dataset)
当我运行代码时,我收到此错误:
RuntimeError: Expected hidden size (2, 24, 50), got (2, 30, 50)
批量大小=30,隐藏大小=50,层数=1,双向=真。
我在最后一批数据中收到该错误。我检查了 PyTorch 中对 RNN 的描述来解决这个问题。PyTorch 中的 RNN 有两个输入参数和两个输出参数。输入参数是input和h_0。h_0是一个张量,包括批量大小(num_layers*num_directions, batch, hidden size)中每个元素的初始隐藏状态。输出参数是output ans h_n。h_n是一个张量,包括大小为 t=seq_len 的隐藏状态(num_layers*num_directions,batch,hidden size)。
在所有批次(最后一批除外)中,h_0 和 h_n 的大小相同。但在最后一批中,元素数量可能小于批量大小。因此 h_n 的大小是 (num_layers num_directions,tained_elements_in_last_batch, hidden size) 但 h_0 的大小仍然是 (num_layers num_directions, batch_size, hidden size)。
所以我在最后一批数据中收到了那个错误。
如何解决这个问题并处理 h_0 和 h_n 的大小不同的情况?
提前致谢。
解决方案
当数据集中的样本数量不是批次大小的倍数时,会发生此错误。忽略最后一批可以解决问题。要识别最后一批,请检查每批中的元素数量。如果小于 BATCH_SIZE,那么它是数据集中的最后一批。
if(len(batch_of_data)==BATCH_SIZE):
output,hidden= model(batch_of_data,hidden)
推荐阅读
- javascript - 如何循环一个函数两次以上?
- python - 用scrapy刮问题
- csv - 如何在 Svelte 应用程序中使用 Axios 获取 CSV?
- mysql - Google App Script SQL 查询返回 bool 'True' 而不是 Int 值,但查询在 App Script 之外工作?
- flutter - 令人敬畏的 fcm 推送通知
- c# - 如何在不应对的情况下重用 XAML 组件及其“.cs”代码?
- python - 每次循环重新启动时,如何防止 for 循环擦除它的输出?
- python-3.x - 禁用 Matplotlib 图表的恒定位置重置
- asp.net-core - 级联参数Task是怎么做的
在 Blazor WASM 中的 AuthorizeView 和 AuthorizedRouteView 中解包并公开为“上下文”? - angular - 显示当前用户(Angular 和 Firebase)