When should one call .eval() and .train() when doing MAML with the PyTorch higher library?

Problem description

I was going through the omniglot MAML example and saw that they call net.train() at the top of the test code. That seems like a mistake, because it means the batch-norm statistics of every meta-test task are shared (a small toy sketch after the quoted code illustrates this):

# Imports needed by this snippet (all present in the original maml-omniglot.py example):
import time

import torch
import torch.nn.functional as F
import higher

def test(db, net, device, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')


        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The query loss and acc induced by these parameters.
                qry_logits = fnet(x_qry[i]).detach()
                qry_loss = F.cross_entropy(
                    qry_logits, y_qry[i], reduction='none')
                qry_losses.append(qry_loss.detach())
                qry_accs.append(
                    (qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })
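
For what it's worth, here is a small toy illustration (my own, not part of the example) of what net.train() means for BatchNorm at meta-test time: the running buffers keep being updated on every forward pass, so each task "leaks" into the statistics seen by the next one.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
bn.train()
task_a = torch.randn(8, 3, 28, 28) + 2.0    # data from one meta-test task
task_b = torch.randn(8, 3, 28, 28) - 3.0    # data from another task with a shifted distribution

print(bn.running_mean)   # initial buffer: zeros
bn(task_a)
print(bn.running_mean)   # moved towards task A's statistics
bn(task_b)
print(bn.running_mean)   # now also pulled towards task B: the tasks share the buffers

bn.eval()
bn(task_a)
print(bn.running_mean)   # in eval mode the buffers are used for normalization but not updated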

However, whenever I do use eval, I see my MAML model diverge (although my test is on mini-imagenet rather than omniglot):

>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5939, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5941, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5942, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5939, grad_fn=<NormBackward1>)
eval_loss=0.9859228551387786, eval_acc=0.5907692521810531
args.meta_learner.lr_inner=0.01
==== in forward2
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(171440.6875, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(208426.0156, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(17067344., grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(40371.8125, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1.0911e+11, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(21.3515, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(5.4257e+13, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(128.9109, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(3994.7734, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1682896., grad_fn=<NormBackward1>)
eval_loss_sanity=nan, eval_acc_santiy=0.20000000298023224

So what should one do to avoid this divergence?

Tags: machine-learning, deep-learning, pytorch, higher, meta-learning

Solution


TL;DR: use mdl.train(), since that normalizes with batch statistics (but inference will no longer be deterministic). You probably do not want to use mdl.eval() in meta-learning.


Expected BatchNorm behaviour:

  • Importantly, during inference (eval/test) the running_mean and running_var are used: the estimates accumulated during training (because the goal is a deterministic output based on an estimate of the population statistics). The sketch after this list shows which statistics each mode actually uses.
  • During training the batch statistics are used, while a population estimate is maintained with running averages. I assume batch statistics are used during training to introduce noise that regularizes training (noise robustness).
  • In meta-learning I think using batch statistics during testing is best (rather than the running estimates), since we are supposed to be seeing new tasks/distributions anyway. The price we pay is a loss of determinism. Out of curiosity, it would be interesting to check the accuracy obtained with population statistics estimated from meta-training.
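
A tiny self-contained check (my own toy example, not code from the MAML repo) of which statistics a BatchNorm2d layer uses in each mode:

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(1)                     # running_mean starts at 0, running_var at 1
x = torch.randn(4, 1, 8, 8) * 3 + 10       # a batch far away from those initial buffers

bn.train()
out_train = bn(x)                          # normalized with this batch's mean/var (buffers get updated too)
bn.eval()
out_eval = bn(x)                           # normalized with running_mean/running_var only

print(out_train.mean().item())             # close to 0: batch statistics were used
print(out_eval.mean().item())              # far from 0: the running estimates were used

In eval mode the output is deterministic given the buffers; in train mode it depends on whatever batch you feed in, which is exactly the trade-off described above.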

This is probably why I do not see this divergence when using mdl.train().

So just make sure you use mdl.train() (since that normalizes with batch statistics, see https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#torch.nn.BatchNorm2d), but also make sure that the running statistics updated during meta-test (which would effectively cheat by peeking at the test tasks) are never saved or used later.
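
One way to follow that advice (a sketch of my own, assuming a standard nn.Module whose BatchNorm layers have track_running_stats=True; the helper name is made up) is to snapshot the BatchNorm buffers before meta-testing and restore them afterwards, so the test tasks can use batch statistics without polluting anything that gets checkpointed:

import torch
import torch.nn as nn

def meta_test_with_batch_stats(net, run_test):
    """Run `run_test()` with `net` in train mode, then restore the BatchNorm buffers."""
    # Collect the BatchNorm layers and snapshot their buffers
    # (running_mean, running_var, num_batches_tracked).
    bn_modules = {
        name: m for name, m in net.named_modules()
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
    }
    saved = {
        name: {b_name: buf.clone() for b_name, buf in m.named_buffers(recurse=False)}
        for name, m in bn_modules.items()
    }

    net.train()          # BatchNorm now normalizes with batch statistics
    result = run_test()

    # Throw away whatever the meta-test tasks wrote into the running buffers.
    with torch.no_grad():
        for name, m in bn_modules.items():
            for b_name, buf in saved[name].items():
                getattr(m, b_name).copy_(buf)
    return result

# e.g. meta_test_with_batch_stats(net, lambda: test(db, net, device, epoch, log))

Another option with the same effect on the buffers is to set each BatchNorm layer's momentum to 0.0 for the duration of meta-testing: in train mode the layer then keeps normalizing with batch statistics but leaves running_mean/running_var untouched.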

