python - PyTorch:ValueError:预期输入 batch_size (256) 与目标 batch_size (128) 匹配
问题描述
我在使用 pytorch 训练 BiLSTM 词性标注器时遇到了 ValueError。ValueError:预期输入 batch_size (256) 与目标 batch_size (128) 匹配。
def train(model, iterator, optimizer, criterion, tag_pad_idx):
    """Run one training epoch; return (mean loss, mean accuracy) over batches."""
    total_loss = 0.0
    total_acc = 0.0
    model.train()
    for batch in iterator:
        text, tags = batch.p, batch.t  # text: [sent len, batch size]
        optimizer.zero_grad()
        # predictions: [sent len, batch size, output dim]
        predictions = model(text)
        # Flatten so the loss compares [sent len * batch size, output dim]
        # logits against [sent len * batch size] targets.
        predictions = predictions.view(-1, predictions.shape[-1])
        flat_tags = tags.view(-1)
        loss = criterion(predictions, flat_tags)
        acc = categorical_accuracy(predictions, flat_tags, tag_pad_idx)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_acc += acc.item()
    return total_loss / len(iterator), total_acc / len(iterator)
def evaluate(model, iterator, criterion, tag_pad_idx):
    """Evaluate without gradient tracking; return (mean loss, mean accuracy)."""
    total_loss = 0.0
    total_acc = 0.0
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            text, tags = batch.p, batch.t
            predictions = model(text)
            # Same flattening as in train(): per-token logits vs per-token tags.
            predictions = predictions.view(-1, predictions.shape[-1])
            flat_tags = tags.view(-1)
            loss = criterion(predictions, flat_tags)
            acc = categorical_accuracy(predictions, flat_tags, tag_pad_idx)
            total_loss += loss.item()
            total_acc += acc.item()
    return total_loss / len(iterator), total_acc / len(iterator)
class BiLSTMPOSTagger(nn.Module):
    """Sequence tagger: embed tokens, run a (bi)LSTM, project each time step to tag logits."""

    def __init__(self,
                 input_dim,
                 embedding_dim,
                 hidden_dim,
                 output_dim,
                 n_layers,
                 bidirectional,
                 dropout,
                 pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_idx)
        # Inter-layer dropout is only meaningful with more than one LSTM layer.
        lstm_dropout = dropout if n_layers > 1 else 0
        self.lstm = nn.LSTM(embedding_dim,
                            hidden_dim,
                            num_layers=n_layers,
                            bidirectional=bidirectional,
                            dropout=lstm_dropout)
        # Bidirectional LSTMs concatenate forward and backward hidden states.
        fc_in = hidden_dim * 2 if bidirectional else hidden_dim
        self.fc = nn.Linear(fc_in, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        # text: [sent len, batch size] -> embedded: [sent len, batch size, emb dim]
        embedded = self.dropout(self.embedding(text))
        outputs, _ = self.lstm(embedded)
        # One logit vector per token position: [sent len, batch size, output dim].
        return self.fc(self.dropout(outputs))
…………………………………………………………………………………………………………………………………………………………………… …………………………………………………………………………………………………………………………………………………………………… ...........
# Hyperparameters and model construction.
# NOTE(review): POS and TAG are presumably torchtext Fields built earlier in
# the notebook — their vocab sizes drive the embedding and output dimensions.
INPUT_DIM = len(POS.vocab)   # word-vocabulary size
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
OUTPUT_DIM = len(TAG.vocab)  # number of distinct POS tags
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.25
# Padding index so padded positions get zero embeddings and no gradient.
PAD_IDX = POS.vocab.stoi[POS.pad_token]
print(INPUT_DIM) #output 22147
print(OUTPUT_DIM) #output 42
model = BiLSTMPOSTagger(INPUT_DIM,
                        EMBEDDING_DIM,
                        HIDDEN_DIM,
                        OUTPUT_DIM,
                        N_LAYERS,
                        BIDIRECTIONAL,
                        DROPOUT,
                        PAD_IDX)
…………………………………………………………………………………………………………………………………………………………………… …………………………………………………………………………………………………………………………………………………………………… ...........
# Main training loop: train/validate each epoch and checkpoint the best model
# by validation loss.
N_EPOCHS = 10
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)
    end_time = time.time()
    # epoch_time is a helper defined elsewhere in the notebook.
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut1-model.pt')
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
ValueError Traceback (most recent call last)
<ipython-input-55-83bf30366feb> in <module>()
7 start_time = time.time()
8
----> 9 train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)
10 valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)
11
4 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
2260 if input.size(0) != target.size(0):
2261 raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 2262 .format(input.size(0), target.size(0)))
2263 if dim == 2:
2264 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
ValueError: Expected input batch_size (256) to match target batch_size (128).
解决方案
(继续评论)
我猜你的批量大小是 128(代码里没有给出),对吧?LSTM 会为每个时间步都产生一个输出,而分类任务通常只需要最后一个时间步。outputs 的第一个维度是序列长度,在你的情况下似乎是 2;当你调用 .view 时,它会与批量大小(128)相乘,得到 256。所以在 LSTM 层之后,你需要只取输出序列 outputs 的最后一个输出,像这样:
def forward(self, text):
    """Classify using only the final LSTM time step.

    NOTE(review): the original snippet reshaped with undefined names
    (batch_size, sequence_size, hidden_size) and assumed batch-first
    output, but nn.LSTM defaults to [sent len, batch size, hid dim];
    index the time axis directly instead.
    """
    embedded = self.dropout(self.embedding(text))
    outputs, (hidden, cell) = self.lstm(embedded)
    # take last output: outputs[-1] is [batch size, hid dim * num directions]
    outputs = outputs[-1]
    predictions = self.fc(self.dropout(outputs))
    return predictions
推荐阅读
- node.js - nodejs mongodb 根据集合是否为空执行不同的查询
- apache - 如何从Combiner/Reducer/Aggregator 函数返回具有多个字段的元组?
- spring-boot - 无法使用spring aop在spring EntityManager中启用休眠过滤器
- javascript - 如何创建角度电子应用程序的exe包
- jasper-reports - 如果详细信息部分跨越多个页面,如何在组标题中重复单个波段?
- quarkus - 如何在没有 ContextNotActive 错误的 PanacheEntity 测试中使用 H2
- rest - 在 REST API 中执行多次删除
- python - 在另一个桌面 Matplotlib Python 上设置绘图
- c# - Azure SignalR 服务 | Asp.Net Web API | 控制台客户端
- lstm - pytorch 中的多步时间序列 LSTM 网络