首页 > 解决方案 > PyTorch:ValueError:预期输入 batch_size (256) 与目标 batch_size (128) 匹配

问题描述

我在使用 PyTorch 训练 BiLSTM 词性标注器时遇到了如下错误:ValueError: Expected input batch_size (256) to match target batch_size (128)(预期输入 batch_size (256) 与目标 batch_size (128) 匹配)。

def train(model, iterator, optimizer, criterion, tag_pad_idx):
    """Run one training epoch and return the mean per-batch loss and accuracy.

    Args:
        model: the tagger; called as ``model(text)`` and switched to train mode.
        iterator: iterable of batches exposing ``.p`` (model input) and ``.t``
            (target tags); ``len(iterator)`` is used as the batch count.
        optimizer: optimizer whose gradients are zeroed and which is stepped
            once per batch.
        criterion: loss applied to the flattened predictions and tags.
        tag_pad_idx: padding tag index, forwarded to ``categorical_accuracy``.

    Returns:
        ``(mean_loss, mean_accuracy)``, each summed over batches and divided by
        ``len(iterator)``.

    Raises:
        ValueError: if the model output does not hold exactly one score vector
            per target tag, i.e. ``predictions.shape[:-1] != tags.shape``.
        ZeroDivisionError: if ``len(iterator)`` is 0.
    """
    epoch_loss = 0
    epoch_acc = 0

    model.train()

    for batch in iterator:
        text = batch.p
        tags = batch.t

        optimizer.zero_grad()

        # assumes text = [sent len, batch size] (seq-first) -- TODO confirm field layout
        predictions = model(text)

        # predictions should be [*tags.shape, output dim]. Checking before
        # flattening reports the real shapes: it catches token-count mismatches
        # (the opaque "256 vs 128" error from the loss) as well as layout
        # mismatches (batch-first vs seq-first) whose element counts agree and
        # would otherwise train silently on misaligned tokens.
        if predictions.shape[:-1] != tags.shape:
            raise ValueError(
                f"model output shape {tuple(predictions.shape)} does not match tags "
                f"shape {tuple(tags.shape)}: expected predictions of shape "
                f"[*tags.shape, output_dim]. Check that batch.p and batch.t have the "
                f"same sequence length and the same (seq-first vs batch-first) layout."
            )

        # reshape rather than view: it also accepts non-contiguous tensors.
        predictions = predictions.reshape(-1, predictions.shape[-1])
        tags = tags.reshape(-1)

        # predictions = [sent len * batch size, output dim]
        # tags = [sent len * batch size]

        loss = criterion(predictions, tags)

        acc = categorical_accuracy(predictions, tags, tag_pad_idx)

        loss.backward()

        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)



def evaluate(model, iterator, criterion, tag_pad_idx):
    """Evaluate ``model`` over ``iterator`` and return mean loss and accuracy.

    The model is put in eval mode and all forward passes run under
    ``torch.no_grad()``, so no gradients are recorded.

    Args:
        model: the tagger; called as ``model(text)``.
        iterator: iterable of batches exposing ``.p`` (model input) and ``.t``
            (target tags); ``len(iterator)`` is used as the batch count.
        criterion: loss applied to the flattened predictions and tags.
        tag_pad_idx: padding tag index, forwarded to ``categorical_accuracy``.

    Returns:
        ``(mean_loss, mean_accuracy)``, each summed over batches and divided by
        ``len(iterator)``.

    Raises:
        ValueError: if the model output does not hold exactly one score vector
            per target tag, i.e. ``predictions.shape[:-1] != tags.shape``.
        ZeroDivisionError: if ``len(iterator)`` is 0.
    """
    epoch_loss = 0
    epoch_acc = 0

    model.eval()

    with torch.no_grad():

        for batch in iterator:

            text = batch.p
            tags = batch.t

            predictions = model(text)

            # predictions should be [*tags.shape, output dim]; validate before
            # flattening so both token-count and layout mismatches are reported
            # with the real shapes instead of failing (or passing) silently.
            if predictions.shape[:-1] != tags.shape:
                raise ValueError(
                    f"model output shape {tuple(predictions.shape)} does not match tags "
                    f"shape {tuple(tags.shape)}: expected predictions of shape "
                    f"[*tags.shape, output_dim]. Check that batch.p and batch.t have the "
                    f"same sequence length and the same (seq-first vs batch-first) layout."
                )

            # reshape rather than view: it also accepts non-contiguous tensors.
            predictions = predictions.reshape(-1, predictions.shape[-1])
            tags = tags.reshape(-1)

            loss = criterion(predictions, tags)

            acc = categorical_accuracy(predictions, tags, tag_pad_idx)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

class BiLSTMPOSTagger(nn.Module):
    """(Bi)LSTM sequence tagger that emits one tag-score vector per input token.

    Args:
        input_dim: number of rows in the token embedding table.
        embedding_dim: size of each token embedding.
        hidden_dim: LSTM hidden size per direction.
        output_dim: number of tag classes produced per token.
        n_layers: number of stacked LSTM layers.
        bidirectional: run the LSTM in both directions; doubles the feature
            size fed to the output layer.
        dropout: dropout probability applied to the embeddings and to the LSTM
            outputs, and between LSTM layers when ``n_layers > 1``.
        pad_idx: embedding index of the padding token.
        batch_first: if True, ``forward`` takes ``[batch, seq len]`` input and
            returns ``[batch, seq len, output_dim]``. The default (False) keeps
            the original ``[seq len, batch]`` layout.
    """

    def __init__(self,
                 input_dim,
                 embedding_dim,
                 hidden_dim,
                 output_dim,
                 n_layers,
                 bidirectional,
                 dropout,
                 pad_idx,
                 batch_first=False):

        super().__init__()

        # Entries at padding_idx do not contribute to the embedding gradient.
        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_idx)

        # nn.LSTM applies its dropout only between stacked layers, so it is
        # disabled for a single-layer LSTM.
        self.lstm = nn.LSTM(embedding_dim,
                            hidden_dim,
                            num_layers=n_layers,
                            bidirectional=bidirectional,
                            dropout=dropout if n_layers > 1 else 0,
                            batch_first=batch_first)

        # A bidirectional LSTM concatenates forward and backward features.
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        """Return per-token tag scores for a tensor of token indices.

        ``text`` is ``[seq len, batch]`` (or ``[batch, seq len]`` when built
        with ``batch_first=True``). The result keeps those leading dimensions
        and adds a trailing ``output_dim``.
        """
        embedded = self.dropout(self.embedding(text))
        # Only the per-step outputs are used; the final (hidden, cell) is discarded.
        outputs, _ = self.lstm(embedded)
        predictions = self.fc(self.dropout(outputs))
        return predictions

(……此处省略部分代码……)

# Sizes derived from the vocabularies built in the omitted code above
# (POS / TAG are presumably torchtext Fields: POS feeds the embedding input,
# TAG defines the output classes -- confirm against that code).
INPUT_DIM = len(POS.vocab)
OUTPUT_DIM = len(TAG.vocab)
PAD_IDX = POS.vocab.stoi[POS.pad_token]

# Architecture hyper-parameters.
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.25

print(INPUT_DIM)   # printed 22147 in the author's run
print(OUTPUT_DIM)  # printed 42 in the author's run

model = BiLSTMPOSTagger(input_dim=INPUT_DIM,
                        embedding_dim=EMBEDDING_DIM,
                        hidden_dim=HIDDEN_DIM,
                        output_dim=OUTPUT_DIM,
                        n_layers=N_LAYERS,
                        bidirectional=BIDIRECTIONAL,
                        dropout=DROPOUT,
                        pad_idx=PAD_IDX)

(……此处省略部分代码……)

# Train for N_EPOCHS epochs, evaluating on the validation iterator after each
# epoch and saving the weights whenever the validation loss improves.
N_EPOCHS = 10

# Lowest validation loss seen so far; +inf so the first finite loss checkpoints.
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)
    
    end_time = time.time()

    # epoch_time is defined in the omitted code; presumably it splits the
    # elapsed time into (minutes, seconds) -- confirm.
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Keep only the state_dict with the best validation loss so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut1-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')





ValueError                                Traceback (most recent call last)
<ipython-input-55-83bf30366feb> in <module>()
      7     start_time = time.time()
      8 
----> 9     train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)
     10     valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)
     11 

4 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2260     if input.size(0) != target.size(0):
   2261         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 2262                          .format(input.size(0), target.size(0)))
   2263     if dim == 2:
   2264         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (256) to match target batch_size (128).

标签: python, neural-network, pytorch, lstm, part-of-speech

解决方案


(继续评论)

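我猜你的批量大小是 128(代码里没有给出定义),对吗?LSTM 会为每个时间步都输出一个结果,但做分类时通常只需要最后一个时间步的输出。outputs 的第一个维度是序列长度,在你的例子里看起来是 2;调用 .view 时,这个 2 会与批量大小 128 相乘,于是得到 256。所以在 LSTM 层之后,你需要取输出序列 outputs 的最后一个时间步。(补充说明:这只适用于整句分类。词性标注需要为每个词都给出预测,不能只取最后一个时间步;而且 tags.view(-1) 同样会把序列长度乘进去。因此 256 与 128 不一致,说明模型输入 batch.p 与标签 batch.t 所含的 token 数不同,建议先打印两者的 shape 进行核对。)像这样: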

def forward(self, text):
    """Return class scores computed from the LSTM output at the last time step.

    Yields one ``[batch, output dim]`` score tensor per sequence, which suits
    sequence-level classification only. For per-token tagging, where tags are
    ``[sent len, batch size]``, every time step needs a prediction. Keep the
    full ``outputs`` there: this ``[batch, output dim]`` result would still not
    match the flattened tags.
    """
    embedded = self.dropout(self.embedding(text))
    outputs, _ = self.lstm(embedded)

    # Index the time axis directly instead of reshaping. reshape keeps the
    # element order, so turning a seq-first [seq, batch, features] tensor into
    # [batch, seq, features] that way would mix tokens from different
    # sequences. features = hidden size * num directions, so it is 2x the
    # hidden size for a bidirectional LSTM.
    # Note: for a bidirectional LSTM, the backward half of the last step has
    # only seen the final token.
    if self.lstm.batch_first:
        last_outputs = outputs[:, -1]   # [batch, features]
    else:
        last_outputs = outputs[-1]      # [batch, features]

    predictions = self.fc(self.dropout(last_outputs))
    return predictions

推荐阅读