首页 > 解决方案 > 如何在 PyTorch 中计算自举交叉熵损失?

问题描述

我读过一些论文,它们使用一种叫做“自举交叉熵损失”的东西来训练他们的分割网络。这个想法是只关注最难的 k%(比如 15%)像素,以提高学习性能,尤其是当简单像素占主导地位时。

目前,我正在使用标准交叉熵:

loss = F.binary_cross_entropy(mask, gt)

如何在 PyTorch 中有效地将其转换为引导版本?

标签: deep-learningneural-networkpytorchloss-function

解决方案


通常我们还会在损失中添加一个“热身”期,以便网络可以学习首先适应容易的区域并过渡到较难的区域。

这个实现从k=10020000 次迭代开始并持续,然后线性衰减到k=15另外 50000 次迭代。

class BootstrappedCE(nn.Module):
    def __init__(self, start_warm=20000, end_warm=70000, top_p=0.15):
        super().__init__()

        self.start_warm = start_warm
        self.end_warm = end_warm
        self.top_p = top_p

    def forward(self, input, target, it):
        if it < self.start_warm:
            return F.cross_entropy(input, target), 1.0

        raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
        num_pixels = raw_loss.numel()

        if it > self.end_warm:
            this_p = self.top_p
        else:
            this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm))
        loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
        return loss.mean(), this_p

推荐阅读