首页 > 解决方案 > 我的“焦点损失”功能是否存在逻辑错误?

问题描述

“焦点损失”的诞生是为了解决困难的样本。我在 pytorch 中用二元交叉熵包装了焦点损失函数:

class FocalLoss(nn.Module):

    def __init__(self, gamma=2):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
    
    def forward(self, pred, label):
        # label is not the one-hot
        true = torch.zeros_like(pred, dtype=torch.float)
        for i, j in enumerate(label):
            true[i, j] = 1.0
    
        loss = nn.BCEWithLogitsLoss(pred, true)
    
        pred_prob = torch.sigmoid(pred)  # sigmoid
        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
        modulating_factor = (1.0 - p_t) ** self.gamma
        loss *= modulating_factor
    
        return loss.mean()

我在 cifar10 数据集中训练我的 resnet18 进行分类任务,并使用 nn.CrossEntropyLoss 来比较准确性。


但结果与我的预期相差甚远:CrossEntropyLoss 的准确率比我的focal loss 高5% 左右!任何人都可以在我上面的焦点损失代码中找到逻辑错误吗?

标签: pytorchloss

解决方案


推荐阅读