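machine-learning - How does one implement a parametrized meta-learner in PyTorch's higher library?
Problem description
I want to implement the meta-LSTM meta-learner from the paper OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING, but I ran into a problem. I cannot get it to work unless I remove what seems to be a crucial line in higher, replacing the deep copy of the param groups with a plain reference, so that the line becomes: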
#self.param_groups = _copy.deepcopy(other.param_groups)
self.param_groups = other.param_groups
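A minimal sketch of why this line can matter, assuming the trainable module is stored inside the param groups as in the code further down: copy.deepcopy clones any nn.Module it finds, so a loss computed through the clone leaves the original module's .grad untouched. The names opt_mdl and param_groups are illustrative only, and the snippet does not use higher at all.

```python
import copy

import torch
import torch.nn as nn

# Illustrative stand-in for the trainable optimizer model kept in a param group.
opt_mdl = nn.Linear(3, 1, bias=False)
param_groups = [{'trainable_opt_model': {'eta': opt_mdl}}]

copied = copy.deepcopy(param_groups)  # what the deep copy line does
shared = param_groups                 # what the plain assignment does

x = torch.ones(1, 3)

# Backprop through the deep-copied module: gradients land on the clone's weights.
copied[0]['trainable_opt_model']['eta'](x).sum().backward()
grad_after_copy = opt_mdl.weight.grad  # still None on the original module

# Backprop through the shared module: gradients land on opt_mdl itself.
shared[0]['trainable_opt_model']['eta'](x).sum().backward()
grad_after_ref = opt_mdl.weight.grad

print(grad_after_copy, grad_after_ref)
```

If this is what happens inside the differentiable optimizer, an outer optimizer that holds opt_mdl.parameters() would never see a gradient for them, which matches the opt_mdl.grad = None output reported below.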
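I have put a very simplified, self-contained implementation of something similar here:
https://gist.github.com/renesax14/8499e0314351ea4199a17e494bff5c4d
but I will also paste it below to keep the discussion in one place: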
# based on the paper "OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING": https://openreview.net/pdf?id=rJY0-Kcll
import torch
from torch.optim import Optimizer

import higher
from higher.optim import DifferentiableOptimizer


class EmptySimpleMetaLstm(Optimizer):
    # placeholder optimizer: it only carries the trainable model/state in its defaults

    def __init__(self, params, trainable_opt_model, trainable_opt_state, *args, **kwargs):
        defaults = {
            'trainable_opt_model': trainable_opt_model,
            'trainable_opt_state': trainable_opt_state,
            'args': args,
            'kwargs': kwargs
        }
        super().__init__(params, defaults)

class SimpleMetaLstm(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_model']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                # get gradient as "data"
                g = g.detach()  # gradients of gradients are not used (no hessians)
                ## very simplified version of meta-lstm meta-learner
                input_metalstm = torch.stack([p, g, prev_lr.view(1,1)]).view(1,3)  # [p, g, prev_lr] note it's missing loss, normalization etc. see original paper
                lr = eta(input_metalstm).view(1)
                fg = 1 - lr  # learnable forget rate
                ## update suggested by meta-lstm meta-learner
                p_new = fg*p - lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr


higher.register_optim(EmptySimpleMetaLstm, SimpleMetaLstm)

def test_parametrized_inner_optimizer():
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from collections import OrderedDict
    import higher

    ## training config
    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    track_higher_grads = True  # if True, during unrolled optimization the graph is retained, and the fast weights will bear grad funcs, so as to permit backpropagation through the optimization process. False during test time for efficiency reasons
    copy_initial_weights = False  # if False then we train the base model's initial weights (i.e. the base model's initialization)
    episodes = 5
    nb_inner_train_steps = 5
    ## get base model
    base_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1, bias=False)),
        ('relu', nn.ReLU())
    ]))
    ## parametrization/mdl for the inner optimizer
    opt_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(3,1, bias=False)),  # 3 inputs: 1 for parameter, 1 for gradient, 1 for previous lr
        ('sigmoid', nn.Sigmoid())
    ]))
    ## get outer optimizer (not differentiable nor trainable)
    outer_opt = optim.Adam([{'params': base_mdl.parameters()},{'params': opt_mdl.parameters()}], lr=0.01)
    for episode in range(episodes):
        ## get fake support & query data (from a single task and 1 data point)
        spt_x, spt_y, qry_x, qry_y = torch.randn(1), torch.randn(1), torch.randn(1), torch.randn(1)
        ## get differentiable & trainable (parametrized) inner optimizer
        inner_opt = EmptySimpleMetaLstm(base_mdl.parameters(), trainable_opt_model={'eta': opt_mdl}, trainable_opt_state={'prev_lr': 0.9*torch.randn(1)})
        with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads) as (fmodel, diffopt):
            for i_inner in range(nb_inner_train_steps):  # this current version implements full gradient descent on k_shot examples (which is usually small, e.g. 5)
                fmodel.train()
                # base/child model forward pass
                inner_loss = 0.5*((fmodel(spt_x) - spt_y))**2
                # inner-opt update
                diffopt.step(inner_loss)
            ## Evaluate on query set for current task
            qry_loss = 0.5*((fmodel(qry_x) - qry_y))**2
            qry_loss.backward()  # for memory efficient computation
        ## outer update
        print(f'episode = {episode}')
        print(f'base_mdl.grad = {base_mdl.fc.weight.grad}')
        print(f'opt_mdl.grad = {opt_mdl.fc.weight.grad}')
        outer_opt.step()
        outer_opt.zero_grad()


if __name__ == '__main__':
    test_parametrized_inner_optimizer()
    print('Done \a')
"""
output when the deep copy is commented out, i.e. param_groups is shared by reference (parametrized optimizer trains properly):
episode = 0
base_mdl.grad = tensor([[-0.0351]])
opt_mdl.grad = tensor([[0.0085, 0.0000, 0.0204]])
episode = 1
base_mdl.grad = tensor([[0.0311]])
opt_mdl.grad = tensor([[-0.0086, -0.0100, 0.0358]])
episode = 2
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = tensor([[0., 0., 0.]])
episode = 3
base_mdl.grad = tensor([[0.0066]])
opt_mdl.grad = tensor([[-0.0016, 0.0000, -0.0032]])
episode = 4
base_mdl.grad = tensor([[-0.0311]])
opt_mdl.grad = tensor([[0.0077, 0.0000, 0.0130]])
Done
output when the deep copy is on (parameters of the inner optimizer are not trained, sad!):
episode = 0
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = None
episode = 1
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = None
episode = 2
base_mdl.grad = tensor([[0.0069]])
opt_mdl.grad = None
episode = 3
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = None
episode = 4
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = None
Done
The deep copy line in higher I am referencing:
self.param_groups = _copy.deepcopy(other.param_groups)
#self.param_groups = other.param_groups
"""
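The real solution
The real solution would be being able to pass an arbitrary dictionary to the differentiable optimizer and do whatever I want with it.
Update:
Perhaps this can be achieved with override:
override (optional) – a dictionary mapping optimizer settings (i.e. those which would be passed to the optimizer constructor or provided within parameter groups) to either singleton lists of override values, or to a list of override values of length equal to the number of parameter groups. If a single override is provided for a keyword, it is used for all parameter groups. If a list is provided, the ith element of the list overrides the corresponding setting in the ith parameter group. This permits the passing of tensors requiring gradient to differentiable optimizers for use as optimizer settings.
It does not work with override: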
Exception has occurred: ValueError
Mismatch between the number of override tensors for optimizer parameter trainable_opt_model and the number of parameter groups.
It seems to check that these lengths match:
def _apply_override(self, override: _OverrideType) -> None:
    for k, v in override.items():
        # Sanity check
        if (len(v) != 1) and (len(v) != len(self.param_groups)):
            ...  # raises the ValueError quoted above
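To make the length rule concrete, here is a stand-alone sketch that mirrors the sanity check quoted above together with the singleton-or-per-group behaviour described in the override documentation. The helper names check_override and resolve_override are hypothetical and are not part of higher; only the rule itself comes from the quoted text.

```python
def check_override(override, num_param_groups):
    """Hypothetical mirror of the quoted sanity check (not higher's API)."""
    for key, values in override.items():
        if len(values) != 1 and len(values) != num_param_groups:
            raise ValueError(
                f"Mismatch between the number of override tensors for optimizer "
                f"parameter {key} and the number of parameter groups."
            )


def resolve_override(override, num_param_groups):
    """Return the per-group settings implied by an override dictionary."""
    check_override(override, num_param_groups)
    return [
        # A singleton list is broadcast to every group; otherwise pick the i-th entry.
        {k: (v[0] if len(v) == 1 else v[i]) for k, v in override.items()}
        for i in range(num_param_groups)
    ]


single = resolve_override({'lr': [0.1]}, num_param_groups=2)
per_group = resolve_override({'lr': [0.1, 0.2]}, num_param_groups=2)

try:
    resolve_override({'lr': [0.1, 0.2, 0.3]}, num_param_groups=2)
    error = None
except ValueError as exc:
    error = str(exc)

print(single, per_group, error)
```

Under this reading, any override value whose len() is neither 1 nor the number of parameter groups triggers the error, so wrapping a single object in a one-element list is one way to satisfy the check.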
I think this is what I need:
inner_opt = EmptySimpleMetaLstm( base_mdl.parameters() )
with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads) as (fmodel, diffopt):
    diffopt.override = {'trainable_opt_model': opt_mdl, 'trainable_opt_state': {'prev_lr': 0.9*torch.randn(1)} }
Cross-posted:
- higher GitHub issue: https://github.com/facebookresearch/higher/issues/62
- https://www.reddit.com/r/pytorch/comments/hbp1n5/how_does_one_implemented_a_parametrized/?
- https://discuss.pytorch.org/t/how-does-one-implemented-a-parametrized-meta-learner-in-pytorchs-higher-library/85988
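Solution
I am not 100% sure this is what you want, but if you can express the inner-loop training as a sequence of updates by, e.g., SGD (or another simple optimizer) whose parameters, e.g. lr, are computed by the NN you want to train (e.g. an LSTM), then you can pass them through override on each call to diffopt.step. Here is a toy example: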
def test_parametrized_inner_optimizer():
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from collections import OrderedDict
    import higher

    ## training config
    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    track_higher_grads = True  # if True, during unrolled optimization the graph is retained, and the fast weights will bear grad funcs, so as to permit backpropagation through the optimization process. False during test time for efficiency reasons
    copy_initial_weights = False  # if False then we train the base model's initial weights (i.e. the base model's initialization)
    episodes = 5
    nb_inner_train_steps = 5
    ## get base model
    base_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1, bias=False)),
        ('relu', nn.LeakyReLU())  # using LeakyReLU here otherwise too often getting 0 gradients on everything :)
    ]))
    ## parametrization/mdl for the inner optimizer
    opt_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(3,1, bias=False)),  # 3 inputs [p, g, prev_lr]: 1 for parameter, 1 for gradient, 1 for previous lr
        ('sigmoid', nn.Sigmoid())
    ]))
    ##
    par = torch.tensor([0.2], requires_grad=True)  # is that what you mean by parameters?
    ## get outer optimizer (not differentiable nor trainable)
    outer_opt = optim.Adam([{'params': base_mdl.parameters()},{'params': opt_mdl.parameters()}], lr=0.01)
    for episode in range(episodes):
        ## get fake support & query data (from a single task and 1 data point)
        spt_x, spt_y, qry_x, qry_y = torch.randn(1), torch.randn(1), torch.randn(1), torch.randn(1)
        ## get differentiable reference optimizer
        inner_opt = torch.optim.SGD(base_mdl.parameters(), lr=0.1)  # lr will be overridden
        with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads) as (fmodel, diffopt):
            prev_lr = torch.tensor([0.1])  # or whatever is supposed to be used at the first step as prev_lr input for opt_mdl
            for i_inner in range(nb_inner_train_steps):  # this current version implements full gradient descent on k_shot examples (which is usually small, e.g. 5)
                fmodel.train()
                # base/child model forward pass
                inner_loss = 0.5*((fmodel(spt_x) - spt_y))**2
                # latest_grad is the same thing as the one we will use for the inner optimizer step, but we
                # need it before the step because we will also produce lr based on it using opt_mdl
                latest_grad = torch.autograd.grad(
                    inner_loss, fmodel.parameters(), retain_graph=True, create_graph=True)  # create_graph makes the returned gradient not just a separate tensor, but something gradients can propagate through
                latest_grad = latest_grad[0].reshape(-1)
                # inner-opt update
                lr_as_output_of_another_model_can_be_lstm_or_whatever = opt_mdl(torch.stack((par, latest_grad, prev_lr)).reshape(1, -1))[0]
                diffopt.step(inner_loss, override={'lr': lr_as_output_of_another_model_can_be_lstm_or_whatever})
                prev_lr = lr_as_output_of_another_model_can_be_lstm_or_whatever
            ## Evaluate on query set for current task
            qry_loss = 0.5*((fmodel(qry_x) - qry_y))**2
            qry_loss.backward()  # for memory efficient computation
        ## outer update
        print(f'episode = {episode}')
        print(f'base_mdl.grad = {base_mdl.fc.weight.grad}')
        print(f'opt_mdl.grad = {opt_mdl.fc.weight.grad}')
        outer_opt.step()
        outer_opt.zero_grad()


test_parametrized_inner_optimizer()
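Note the use of create_graph=True in addition to retain_graph=True in the call to torch.autograd.grad. Without it, gradients would not propagate through the gradient input into the NN that computes lr here (paragraph 3.3.1 of the referenced paper suggests simplifying the computation by not propagating gradients there, but here you can choose either behaviour by passing create_graph=True or leaving it out).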
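The effect of create_graph can be checked in isolation with plain PyTorch. The sketch below is an illustration on a made-up scalar problem (w, x, y and inner_loss are hypothetical names, not taken from the toy example): with create_graph=True the returned gradient is itself part of the autograd graph, so anything computed from it, such as a learned lr, can send gradients back through it.

```python
import torch

# Made-up one-parameter regression problem, for illustration only.
x, y = torch.tensor([2.0]), torch.tensor([1.0])
w = torch.tensor([1.5], requires_grad=True)


def inner_loss(weight):
    return 0.5 * ((weight * x - y) ** 2).sum()


# Default: the gradient comes back as plain data, detached from the graph.
(g_plain,) = torch.autograd.grad(inner_loss(w), w)

# create_graph=True: the gradient keeps a grad_fn and can be differentiated again.
(g_graph,) = torch.autograd.grad(inner_loss(w), w, create_graph=True)

# Differentiating the gradient itself gives the second derivative, here x**2.
(second,) = torch.autograd.grad(g_graph.sum(), w)

print(g_plain.requires_grad, g_graph.requires_grad, second)
```

Leaving create_graph out corresponds to treating the gradient input as a constant, which is the simplification the note above attributes to the paper.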