首页 > 解决方案 > Pytorch:如何找到多标签分类的准确性?

问题描述

我正在使用 vgg16,其中类数为 3,并且我可以为一个数据点预测多个标签。

vgg16 = models.vgg16(pretrained=True) vgg16.classifier[6]= nn.Linear(4096, 3)

使用损失函数:nn.BCEWithLogitsLoss()

在单个标签问题的情况下,我能够找到准确性,因为

 `images, labels = data
 images, labels = images.to(device), labels.to(device)
 labels = Encode(labels)
 outputs = vgg16(images)
 _, predicted = torch.max(outputs.data, 1)
 total += labels.size(0)
 correct += (predicted == labels).sum().item()
 acc = (100 * correct / total)`

如何找到多标签分类的准确性?

标签: deep-learningpytorch

解决方案


从您的问题来看,vgg16是返回原始 logits。因此,您可以执行以下操作:

labels = Encode(labels)  # torch.Size([N, C]) e.g. tensor([[1., 1., 1.]])
outputs = vgg16(images)  # torch.Size([N, C])
outputs = torch.sigmoid(outputs)  # torch.Size([N, C]) e.g. tensor([[0., 0.5, 0.]])
outputs[outputs >= 0.5] = 1
accuracy = (outputs == labels).sum()/(N*C)*100

推荐阅读