Why does the distillation loss not converge (decrease), while the cross-entropy loss does?

Problem description

I am using a cross-entropy loss together with a distillation loss in a continual-learning (a.k.a. incremental-learning) model. However, the distillation loss does not converge, while the cross-entropy loss does.
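
In other words, the distillation term is the usual temperature-softened soft cross-entropy between the previous model's outputs and the current model's outputs restricted to the old classes (this is the standard knowledge-distillation formulation that the code below implements); for a batch of N samples:

    L_distill = -\frac{1}{N} \sum_{n} \sum_{i} \mathrm{softmax}(z^{prev}_{n}/T)_{i} \, \log \mathrm{softmax}(z^{curr}_{n}/T)_{i}

where z^{prev} and z^{curr} are the logits of the previous and the current model and T is the temperature.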

Here is the code for my distillation loss and for the training part.

Distillation loss code:

import torch
import torch.nn.functional as F
from torch.autograd import Variable as V


def loss_fn_distillation(outputs, soft_labels, temperature, current_step, total_step, total_label):
    # Number of classes seen after the current step and after the previous step.
    current_label = (total_label / total_step) * (current_step + 1)
    previous_label = (total_label / total_step) * current_step

    # Temperature-softened targets from the previous model (no gradient).
    soft_labels = V(soft_labels.data, requires_grad=False).cuda()
    soft_labels = torch.softmax(soft_labels / temperature, dim=1)

    # Temperature-softened log-probabilities of the current model, restricted to the old classes.
    outputs = F.log_softmax(outputs[:, :-int(current_label - previous_label)] / temperature, dim=1)

    # Soft cross-entropy between the two distributions, averaged over the batch.
    distill_loss = torch.sum(outputs * soft_labels, dim=1, keepdim=False)
    distill_loss = -torch.mean(distill_loss, dim=0, keepdim=False)

    return V(distill_loss, requires_grad=True).cuda()

Training code:

    # Forward pass through the current model and the usual cross-entropy loss.
    outputs = net(inputs)
    ce_loss = criterion(outputs, targets)

    if i > 0:
        # Soft labels come from the model trained on the previous tasks.
        soft_label = previous_net(inputs)

        distill_loss = loss_fn_distillation(outputs=outputs, soft_labels=soft_label, temperature=2,
                                            current_step=i, total_step=step, total_label=number_label)
        print(ce_loss, distill_loss)
        # Combine the two terms with equal weight.
        loss = distill_loss + ce_loss
    else:
        loss = ce_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Cross-entropy loss and distillation loss per epoch:

(screenshot of the printed loss values)

I would appreciate any feedback. Thank you.

Tags: python, deep-learning, pytorch

Solution
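
One thing that stands out in the posted loss function (a guess from the code alone, not a verified fix): the return value is re-wrapped as V(distill_loss, requires_grad=True). Wrapping an already-computed tensor in a fresh Variable creates a new leaf that is cut off from the computation graph, so loss.backward() sends no gradient from the distillation term back into net; the network is then driven only by the cross-entropy term, and nothing is pushing the distillation loss down. A quick check is to print distill_loss.grad_fn right before backward(): if it is None, the term is detached. Below is a minimal sketch of the same loss that keeps the graph intact, assuming the same signature and class bookkeeping as in the question (the teacher's logits are detached because they are targets only):

import torch
import torch.nn.functional as F


def loss_fn_distillation(outputs, soft_labels, temperature, current_step, total_step, total_label):
    # Same class bookkeeping as in the question.
    current_label = (total_label / total_step) * (current_step + 1)
    previous_label = (total_label / total_step) * current_step
    num_new = int(current_label - previous_label)

    # Teacher outputs are targets only: detach them instead of re-wrapping in a Variable.
    soft_targets = torch.softmax(soft_labels.detach() / temperature, dim=1)

    # Student log-probabilities over the old classes, with the same temperature.
    log_probs = F.log_softmax(outputs[:, :-num_new] / temperature, dim=1)

    # Soft cross-entropy averaged over the batch; returned directly so that
    # gradients can flow back into the current network.
    return -torch.mean(torch.sum(soft_targets * log_probs, dim=1))

If the term still refuses to go down after that, common next steps (not shown in the question) are to weight the two losses, e.g. loss = ce_loss + lam * distill_loss with a tunable lam, and/or to multiply the distillation term by temperature ** 2 as in the original knowledge-distillation formulation.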

