Why not accumulate the query losses and then take the derivative in MAML with PyTorch and higher?

Problem description

When doing MAML (Model-Agnostic Meta-Learning), there are two ways to implement the inner loop:


import time

import torch
import torch.nn.functional as F
import higher

# Assumed to be defined by the surrounding training script:
# net, meta_opt, x_spt, y_spt, x_qry, y_qry, task_num, querysz,
# epoch, batch_idx, n_train_iter, start_time.

def inner_loop1():
    n_inner_iter = 5
    inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

    qry_losses = []
    qry_accs = []
    meta_opt.zero_grad()
    for i in range(task_num):
        with higher.innerloop_ctx(
            net, inner_opt, copy_initial_weights=False
        ) as (fnet, diffopt):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            # higher is able to automatically keep copies of
            # your network's parameters as they are being updated.
            for _ in range(n_inner_iter):
                spt_logits = fnet(x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                diffopt.step(spt_loss)

            # The final set of adapted parameters will induce some
            # final loss and accuracy on the query dataset.
            # These will be used to update the model's meta-parameters.
            qry_logits = fnet(x_qry[i])
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_losses.append(qry_loss.detach())
            qry_acc = (qry_logits.argmax(
                dim=1) == y_qry[i]).sum().item() / querysz
            qry_accs.append(qry_acc)

            # Update the model's meta-parameters to optimize the query
            # losses across all of the tasks sampled in this batch.
            # This unrolls through the gradient steps.
            qry_loss.backward()

    meta_opt.step()
    qry_losses = sum(qry_losses) / task_num
    qry_accs = 100. * sum(qry_accs) / task_num
    i = epoch + float(batch_idx) / n_train_iter
    iter_time = time.time() - start_time

def inner_loop2():
    n_inner_iter = 5
    inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

    qry_losses = []
    qry_accs = []
    meta_opt.zero_grad()
    meta_loss = 0
    for i in range(task_num):
        with higher.innerloop_ctx(
            net, inner_opt, copy_initial_weights=False
        ) as (fnet, diffopt):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            # higher is able to automatically keep copies of
            # your network's parameters as they are being updated.
            for _ in range(n_inner_iter):
                spt_logits = fnet(x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                diffopt.step(spt_loss)

            # The final set of adapted parameters will induce some
            # final loss and accuracy on the query dataset.
            # These will be used to update the model's meta-parameters.
            qry_logits = fnet(x_qry[i])
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_losses.append(qry_loss.detach())
            qry_acc = (qry_logits.argmax(
                dim=1) == y_qry[i]).sum().item() / querysz
            qry_accs.append(qry_acc)

            # Update the model's meta-parameters to optimize the query
            # losses across all of the tasks sampled in this batch.
            # This unrolls through the gradient steps.
            #qry_loss.backward()
            meta_loss += qry_loss

    meta_loss.backward()
    meta_opt.step()
    qry_accs = 100. * sum(qry_accs) / task_num
    i = epoch + float(batch_idx) / n_train_iter
    iter_time = time.time() - start_time

Are they really equivalent?



Tags: machine-learning, pytorch

Solution


The only difference is that the second approach keeps more things in memory: until you call `backward`, you hold the unrolled parameters `fnet.parameters(time=T)` (along with the intermediate computation tensors) for all `task_num` tasks as part of the aggregated graph for `meta_loss`. If instead you call `backward` for each task, you only ever need to keep the full set of unrolled parameters (and the other parts of the graph) for a single task.

So, to answer the title of your question: because in that case the memory footprint is roughly `task_num` times larger.

In short, what you are doing is similar to comparing `loopA(N)` and `loopB(N)` in the code below. Here `loopA` will grab as much memory as it can get and will OOM for a large enough `N`, while `loopB` uses approximately the same amount of memory for any large `N`:

import torch
import numpy as np

np.random.seed(1)
v = torch.tensor(np.random.randn(1000000))
y = torch.tensor(np.random.randn(1000000))
x = torch.zeros(1000000, requires_grad=True)

def loopA(N=1000):
    # Accumulate all N losses into one graph, then call backward once:
    # every iteration's intermediate tensors stay alive until backward.
    a = 0
    for i in range(N):
        a += ((x * v - y)**2).sum()
    a.backward()

def loopB(N=1000):
    # Call backward after each loss: the graph for one iteration is
    # freed before the next is built, so memory stays flat in N.
    for i in range(N):
        a = ((x * v - y)**2).sum()
        a.backward()
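
If you want to see this concretely, one rough check (a sketch, not part of the original answer) is to recreate the tensors on a GPU and compare peak memory with `torch.cuda.max_memory_allocated`; the helper `peak_mem_mib` and the choice of `N` here are illustrative assumptions:

import torch
import numpy as np

# Recreate the tensors above on a CUDA device (assumes one is available),
# so loopA/loopB from the snippet above operate on GPU memory.
np.random.seed(1)
v = torch.tensor(np.random.randn(1000000), device="cuda")
y = torch.tensor(np.random.randn(1000000), device="cuda")
x = torch.zeros(1000000, requires_grad=True, device="cuda")

def peak_mem_mib(fn, n):
    # Reset the peak-memory counter, run fn, and report the peak in MiB.
    torch.cuda.reset_peak_memory_stats()
    fn(n)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

print(peak_mem_mib(loopA, 100))  # peak grows roughly linearly with N
print(peak_mem_mib(loopB, 100))  # peak stays roughly flat in N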

Regarding the normalization: the two approaches are equivalent (possibly up to numerical precision). If you first sum up the individual losses, divide by `task_num`, and only then call `backward`, you effectively compute `d((Loss_1 + ... + Loss_{task_num}) / task_num) / dw` (where `w` is one of the weights the meta-optimizer fits). If, on the other hand, you call `backward` on each loss divided by `task_num`, you get `d(Loss_1/task_num)/dw + ... + d(Loss_{task_num}/task_num)/dw`, which is the same result because taking the gradient is a linear operation. So in both cases your meta-optimizer step starts from essentially the same gradients.
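
A quick numerical check of that linearity argument (a minimal sketch; the quadratic `task_loss` is a made-up stand-in for the per-task query losses):

import torch

task_num = 4
w = torch.randn(10, requires_grad=True)

def task_loss(i):
    # Made-up per-task loss, standing in for Loss_i(w).
    return ((w - i) ** 2).sum()

# Sum first, normalize, then a single backward.
loss = sum(task_loss(i) for i in range(task_num)) / task_num
loss.backward()
g_sum_then_backward = w.grad.clone()

# Per-task backward on the normalized losses; gradients accumulate in w.grad.
w.grad = None
for i in range(task_num):
    (task_loss(i) / task_num).backward()
g_backward_per_task = w.grad.clone()

print(torch.allclose(g_sum_then_backward, g_backward_per_task))  # True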

