machine-learning - Why not accumulate the query loss and then take the derivative in MAML with PyTorch and `higher`?
Question
When doing MAML (Model-Agnostic Meta-Learning), there are two ways to run the inner loop:
def inner_loop1():
    n_inner_iter = 5
    inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

    qry_losses = []
    qry_accs = []
    meta_opt.zero_grad()
    for i in range(task_num):
        with higher.innerloop_ctx(
            net, inner_opt, copy_initial_weights=False
        ) as (fnet, diffopt):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            # higher is able to automatically keep copies of
            # your network's parameters as they are being updated.
            for _ in range(n_inner_iter):
                spt_logits = fnet(x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                diffopt.step(spt_loss)

            # The final set of adapted parameters will induce some
            # final loss and accuracy on the query dataset.
            # These will be used to update the model's meta-parameters.
            qry_logits = fnet(x_qry[i])
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_losses.append(qry_loss.detach())
            qry_acc = (qry_logits.argmax(
                dim=1) == y_qry[i]).sum().item() / querysz
            qry_accs.append(qry_acc)

            # Update the model's meta-parameters to optimize the query
            # losses across all of the tasks sampled in this batch.
            # This unrolls through the gradient steps.
            qry_loss.backward()

    meta_opt.step()
    qry_losses = sum(qry_losses) / task_num
    qry_accs = 100. * sum(qry_accs) / task_num
    i = epoch + float(batch_idx) / n_train_iter
    iter_time = time.time() - start_time
def inner_loop2():
    n_inner_iter = 5
    inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

    qry_losses = []
    qry_accs = []
    meta_opt.zero_grad()
    meta_loss = 0
    for i in range(task_num):
        with higher.innerloop_ctx(
            net, inner_opt, copy_initial_weights=False
        ) as (fnet, diffopt):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            # higher is able to automatically keep copies of
            # your network's parameters as they are being updated.
            for _ in range(n_inner_iter):
                spt_logits = fnet(x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                diffopt.step(spt_loss)

            # The final set of adapted parameters will induce some
            # final loss and accuracy on the query dataset.
            # These will be used to update the model's meta-parameters.
            qry_logits = fnet(x_qry[i])
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_losses.append(qry_loss.detach())
            qry_acc = (qry_logits.argmax(
                dim=1) == y_qry[i]).sum().item() / querysz
            qry_accs.append(qry_acc)

            # Update the model's meta-parameters to optimize the query
            # losses across all of the tasks sampled in this batch.
            # This unrolls through the gradient steps.
            # qry_loss.backward()
            meta_loss += qry_loss

    meta_loss.backward()
    meta_opt.step()
    qry_accs = 100. * sum(qry_accs) / task_num
    i = epoch + float(batch_idx) / n_train_iter
    iter_time = time.time() - start_time
Are they really equivalent?
Solution
The only difference is that in the second approach you have to keep more things in memory: until you call `backward`, you are holding the unrolled parameters `fnet.parameters(time=T)` from every inner iteration (as well as the intermediate computation tensors) for all `task_num` tasks, because they are all part of the aggregated `meta_loss` graph. If instead you call `backward` once per task, you only ever need to keep the full set of unrolled parameters (and the rest of the graph) for a single task. So, to answer the question in your title: because in the accumulated case the memory footprint is about `task_num` times larger.
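The per-task `backward` pattern works because PyTorch accumulates gradients into `.grad` across successive `backward` calls, so summing gradients task by task yields the same result as differentiating the summed loss. A minimal sketch, with toy scalar losses standing in for the per-task query losses:

```python
import torch

# One shared parameter and two toy "task losses".
x = torch.tensor(2.0, requires_grad=True)

(x ** 2).backward()   # d(x^2)/dx = 4 at x = 2; x.grad is now 4
(3.0 * x).backward()  # d(3x)/dx = 3, accumulated into x.grad

# x.grad holds 4 + 3 = 7, the gradient of the summed loss x^2 + 3x.
print(x.grad.item())  # -> 7.0
```

This accumulation is also why both inner loops call `meta_opt.zero_grad()` before the task loop starts.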
In short, what you are doing is comparable to comparing `loopA(N)` and `loopB(N)` in the code below. Here `loopA` will grab memory proportional to `N` and will OOM for large enough `N`, while `loopB` will use roughly the same amount of memory for any `N`:
import torch
import numpy as np

a = 0
np.random.seed(1)
v = torch.tensor(np.random.randn(1000000))
y = torch.tensor(np.random.randn(1000000))
x = torch.zeros(1000000, requires_grad=True)

def loopA(N=1000):
    a = 0
    for i in range(N):
        a += ((x * v - y)**2).sum()
    a.backward()

def loopB(N=1000):
    for i in range(N):
        a = ((x * v - y)**2).sum()
        a.backward()
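The reason `loopB` stays flat in memory is that PyTorch frees a computation graph's intermediate buffers as soon as `backward` runs on it (unless `retain_graph=True` is passed), so each iteration's graph is released before the next one is built. A small sketch illustrating that behavior:

```python
import torch

x = torch.tensor(1.0, requires_grad=True)
a = (x * 2.0) ** 2
a.backward()  # the graph for `a` is freed here

# A second backward on the same graph fails, showing its buffers are gone.
try:
    a.backward()
    freed = False
except RuntimeError:
    freed = True
print(freed)  # -> True
```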
Regarding normalization: the two approaches are equivalent (up to numerical precision). If you first sum the individual losses, divide by `task_num`, and only then call `backward`, you effectively compute d((Loss_1 + ... + Loss_{task_num}) / task_num) / dw (where w is one of the weights the meta-optimizer fits). On the other hand, if you call `backward` on each loss divided by `task_num`, you get d(Loss_1/task_num)/dw + ... + d(Loss_{task_num}/task_num)/dw, which is the same result because taking a gradient is a linear operation. So in both cases your meta-optimizer step starts from (almost) the same gradients.
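Because gradient computation is linear, the two normalization orders can be checked numerically. A minimal sketch, with made-up quadratic losses standing in for the per-task query losses:

```python
import torch

task_num = 3
scales = [1.0, 2.0, 3.0]  # toy coefficients for the per-task losses

# Approach A: sum the losses, divide by task_num, one backward call.
w = torch.tensor([1.0, -2.0], requires_grad=True)
meta_loss = sum(s * (w ** 2).sum() for s in scales) / task_num
meta_loss.backward()
grad_a = w.grad.clone()

# Approach B: backward on each loss / task_num; grads accumulate in w.grad.
w = torch.tensor([1.0, -2.0], requires_grad=True)
for s in scales:
    (s * (w ** 2).sum() / task_num).backward()
grad_b = w.grad.clone()

print(torch.allclose(grad_a, grad_b))  # -> True: the gradients match
```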