首页 > 解决方案 > 输入批量大小与 CrossEntropyLoss 函数中的目标批量大小不匹配

问题描述

我一直在尝试在PyTorchFastAI类的帮助下从头开始构建一个模型来识别MNIST数据集中的手写数字。到目前为止,我一直在使用一个线性模型,它有 784 个输入(一个扁平的灰度 28 x 28 手写数字图像张量)和 10 个输出。DataLoader

simple_linear = torch.nn.Linear(784, 10)

我的训练数据是这样组织的:

train_x = torch.cat([stacked_zeros, stacked_ones, stacked_twos, stacked_threes, 
                     stacked_fours, stacked_fives, stacked_sixes, stacked_sevens, 
                     stacked_eights, stacked_nines]).view(-1, 28*28)

train_y = torch.nn.functional.one_hot(tensor([0] * len(zeros) + [1] * len(ones) + [2] * len(twos) + 
                 [3] * len(threes) + [4] * len(fours) + [5] * len(fives) + 
                 [6] * len(sixes) + [7] * len(sevens) + [8] * len(eights) + 
                 [9] * len(nines)).unsqueeze(1))

我的x变量具有形状[784],而y变量使用带有形状的单热编码向量进行标记[1, 10]

我根据研究选择的损失函数是torch.nn.CrossEntropyLoss,下面的代码给了我一个错误:

mnist_loss = torch.nn.CrossEntropyLoss()
mnist_loss(simple_linear(train_x[0]), train_y[0])

ValueError                                Traceback (most recent call last)
<ipython-input-245-03f54a6a43fb> in <module>()
----> 1 tst(simple_linear(x), y)

8 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2260     if input.size(0) != target.size(0):
   2261         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 2262                          .format(input.size(0), target.size(0)))
   2263     if dim == 2:
   2264         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (1) to match target batch_size (10).

我试过重塑我的x变量和y变量,但我总是遇到类似的错误。为了使损失函数起作用,我的数据必须如何构造?

标签: pythonpytorch

解决方案


torch.nn.CrossEntropyLoss函数不将目标作为单热编码!

只需通过标签索引,所以基本上:

train_y = torch.tensor([0] * len(zeros) + [1] * len(ones) + [2] * len(twos) + 
                 [3] * len(threes) + [4] * len(fours) + [5] * len(fives) + 
                 [6] * len(sixes) + [7] * len(sevens) + [8] * len(eights) + 
                 [9] * len(nines)).unsqueeze(1)

这是一个建议,您可以编写如下所有内容:

dataset = [stacked_zeros, stacked_ones, stacked_twos, stacked_threes,
           stacked_fours, stacked_fives, stacked_sixes, stacked_sevens, 
           stacked_eights, stacked_nines]

train_x = torch.cat(dataset)
train_y = torch.tensor([[i]*d.size(0) for i, d in enumerate(dataset)])

推荐阅读