首页 > 解决方案 > Pytorch:涉及端到端雅可比范数的自定义损失

问题描述

Pytorch 讨论区的交叉发帖

我想使用修改后的损失函数来训练网络,该损失函数既有典型的分类损失(例nn.CrossEntropyLoss​​如网络,\nabla_x f(x))。

我已经实现了一个可以成功学习的模型nn.CrossEntropyLoss。但是,当我尝试添加第二个损失函数(通过两次向后传递)时,我的训练循环会运行,但模型永远不会学习。此外,如果我计算端到端雅可比行列式,但不将其包含在损失函数中,则模型也永远不会学习。在高层次上,我的代码执行以下操作:

  1. 前向传递以yhat从输入中获取预测类别 ,x
  2. 称呼yhat.backward(torch.ones(appropriate shape), retain_graph=True)
  3. 雅可比范数 =x.grad.data.norm(2)
  4. 设置损失等于分类损失 + 标量系数 * 雅可比范数
  5. loss.backward()

我怀疑我误解了backward()运行两次时的工作原理,但我无法找到任何好的资源来澄清这一点。

产生一个工作示例需要太多,所以我试图提取相关代码:

def train_model(model, train_dataloader, optimizer, loss_fn, device=None):

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (batch_input, batch_target) in enumerate(train_dataloader):
        batch_input, batch_target = batch_input.to(device), batch_target.to(device)
        optimizer.zero_grad()
        batch_input.requires_grad_(True)
        model_batch_output = model(batch_input)
        loss = loss_fn(model_output=model_batch_output, model_input=batch_input, model=model, target=batch_target)
        train_loss += loss.item()  # sum up batch loss
        loss.backward()
        optimizer.step()

    def end_to_end_jacobian_loss(model_output, model_input):
        model_output.backward(
            torch.ones(*model_output.shape),
            retain_graph=True)
        jacobian = model_input.grad.data
        jacobian_norm = jacobian.norm(2)
        return jacobian_norm

编辑 1:我将以前的实现换成了.backward()toautograd.grad并且它显然有效!有什么不同?

    def end_to_end_jacobian_loss(model_output, model_input):
        jacobian = autograd.grad(
            outputs=model_output['penultimate_layer'],
            inputs=model_input,
            grad_outputs=torch.ones(*model_output['penultimate_layer'].shape),
            retain_graph=True,
            only_inputs=True)[0]
        jacobian_norm = jacobian.norm(2)
        return jacobian_norm

标签: pytorchloss-functionautodiff

解决方案


推荐阅读