首页 > 解决方案 > 如何在大不平衡体积上设置多类软骰子损失函数的平滑度?

问题描述

使用V-Net 论文中提出的平滑骰子函数:

公式

为 pytorch 中的多个类进行编码,并添加了平滑,以便它始终是可分的:

class DiceLoss(torch.nn.Module):
    def __init__(self, smooth=1):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        assert inputs.shape == targets.shape, f"Shapes don't match {inputs.shape} != {targets.shape}"
        inputs = inputs[:,1:]                                                       # skip background class
        targets = targets[:,1:]                                                     # skip background class
        axes = tuple(range(2, len(inputs.shape)))                                   # sum over elements per sample and per class
        intersection = torch.sum(inputs * targets, axes)
        addition = torch.sum(torch.square(inputs) + torch.square(targets), axes)
        return 1 - torch.mean((2 * intersection + self.smooth) / (addition + self.smooth))

目标是大小为 52^3 = 140.608 的体积样本补丁。样本与具有 97% 背景、2% 肝脏和 1% 肿瘤的体积不重叠。因此,大多数补丁将完全是背景。骰子损失限制在 0 和 1 之间。

其他问题将平滑度设置为1.0或更低1e-7

假设一个补丁是完全背景的(大多数情况下),并且所有元素的预测是 0.01 (=1%) 肝脏和 0.01 肿瘤。将平滑度设置为 1.0 仍然会导致

1 - (2 x 0 + 1) / (0 + 140.608 x 0.0001 + 1) = 0.9336

即使预测准确,也接近最大损失。

这迫使网络将肝脏或肿瘤的可能性降低到 0.001 或更低。通常用作骰子损失的平滑度设置是什么?

标签: pytorchloss-functiondice

解决方案


推荐阅读