python - 输入批量大小与 CrossEntropyLoss 函数中的目标批量大小不匹配
问题描述
我一直在尝试在PyTorch和FastAI类的帮助下从头开始构建一个模型来识别MNIST数据集中的手写数字。到目前为止,我一直在使用一个线性模型,它有 784 个输入(一个扁平的灰度 28 x 28 手写数字图像张量)和 10 个输出。DataLoader
simple_linear = torch.nn.Linear(784, 10)
我的训练数据是这样组织的:
train_x = torch.cat([stacked_zeros, stacked_ones, stacked_twos, stacked_threes,
stacked_fours, stacked_fives, stacked_sixes, stacked_sevens,
stacked_eights, stacked_nines]).view(-1, 28*28)
train_y = torch.nn.functional.one_hot(tensor([0] * len(zeros) + [1] * len(ones) + [2] * len(twos) +
[3] * len(threes) + [4] * len(fours) + [5] * len(fives) +
[6] * len(sixes) + [7] * len(sevens) + [8] * len(eights) +
[9] * len(nines)).unsqueeze(1))
我的x
变量具有形状[784]
,而y
变量使用带有形状的单热编码向量进行标记[1, 10]
。
我根据研究选择的损失函数是torch.nn.CrossEntropyLoss
,下面的代码给了我一个错误:
mnist_loss = torch.nn.CrossEntropyLoss()
mnist_loss(simple_linear(train_x[0]), train_y[0])
ValueError Traceback (most recent call last)
<ipython-input-245-03f54a6a43fb> in <module>()
----> 1 tst(simple_linear(x), y)
8 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
2260 if input.size(0) != target.size(0):
2261 raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 2262 .format(input.size(0), target.size(0)))
2263 if dim == 2:
2264 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
ValueError: Expected input batch_size (1) to match target batch_size (10).
我试过重塑我的x
变量和y
变量,但我总是遇到类似的错误。为了使损失函数起作用,我的数据必须如何构造?
解决方案
该torch.nn.CrossEntropyLoss
函数不将目标作为单热编码!
只需通过标签索引,所以基本上:
train_y = torch.tensor([0] * len(zeros) + [1] * len(ones) + [2] * len(twos) +
[3] * len(threes) + [4] * len(fours) + [5] * len(fives) +
[6] * len(sixes) + [7] * len(sevens) + [8] * len(eights) +
[9] * len(nines)).unsqueeze(1)
这是一个建议,您可以编写如下所有内容:
dataset = [stacked_zeros, stacked_ones, stacked_twos, stacked_threes,
stacked_fours, stacked_fives, stacked_sixes, stacked_sevens,
stacked_eights, stacked_nines]
train_x = torch.cat(dataset)
train_y = torch.tensor([[i]*d.size(0) for i, d in enumerate(dataset)])
推荐阅读
- python - mplcursors 与散点图端点的交互性
- docker-compose - 当 dns 查找返回服务器域的多个 IP 地址时的 haproxy 行为
- dictionary - Ansible 过滤器将特定键从字典中提取到另一个字典中
- spring-boot - 该值不是通过注入加载的
- unity3d - 为什么我的玩家没有受到敌人的伤害?
- typescript - 如果测试用例并行运行,如何在量角器中创建报告?
- c# - 如何让角色转向相机的方向?[Unity3D]
- python - 当我通过 python 中的文件对话框打开图像时,为什么会收到“UnicodeDecodeError”?
- javascript - 在 JavaScript 中创建迭代器时,数组未传递给 next() 方法
- android - 错误警告:应避免使用 flatDirs,因为它不支持任何元数据格式