首页 > 解决方案 > 带有 3d 输入的 Pytorch 交叉熵损失

问题描述

我有一个输出 3D 张量 size 的网络(batch_size, max_len, num_classes)。我的真实情况是形状(batch_size, max_len)。如果我确实对标签执行 one-hot 编码,它将具有形状,(batch_size, max_len, num_classes)即 in 中的值max_len是 range 中的整数[0, num_classes]。由于原始代码太长,我写了一个更简单的版本来重现原始错误。

criterion = nn.CrossEntropyLoss()
batch_size = 32
max_len = 350
num_classes = 1000
pred = torch.randn([batch_size, max_len, num_classes])
label = torch.randint(0, num_classes,[batch_size, max_len])
pred = nn.Softmax(dim = 2)(pred)
criterion(pred, label)

pred 和 label 的形状分别torch.Size([32, 350, 1000])torch.Size([32, 350])

遇到的错误是

ValueError: Expected target size (32, 1000), got torch.Size([32, 350, 1000])

如果我用 one-hot 编码标签来计算损失

x = nn.functional.one_hot(label)
criterion(pred, x)

它会抛出以下错误

ValueError: Expected target size (32, 1000), got torch.Size([32, 350, 1000])

标签: pythonneural-networkpytorchcross-entropy

解决方案


Pytorch 文档中,CrossEntropyLoss期望其输入的形状为(N, C, ...),因此第二维始终是类的数量。如果您重塑preds为 size ,您的代码应该可以工作(batch_size, num_classes, max_len)


推荐阅读