首页 > 解决方案 > 如何在pytorch中返回中间梯度(对于非叶节点)?

问题描述

我的问题是关于 pytorch 的语法register_hook

x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y

x.register_hook(print)
y.register_hook(print)

z.backward()

输出:

tensor([2.])
tensor([4.])

z此代码段仅分别打印wrtx和的梯度y

现在我(很可能是微不足道的)问题是如何返回中间渐变(而不仅仅是打印)?

更新:

看来调用retain_grad()解决了叶节点的问题。前任。y.retain_grad().

但是,retain_grad似乎没有解决非叶节点的问题。有什么建议么?

标签: pythongradientpytorch

解决方案


我认为您可以使用这些挂钩将渐变存储在全局变量中:

grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y

x.register_hook(lambda d:grads.append(d))
y.register_hook(lambda d:grads.append(d))

z.backward()

但是您很可能还需要记住计算这些梯度的相应张量。在这种情况下,我们使用 adict而不是稍微扩展一下list

grads = {}
x = torch.tensor([1.,2.], requires_grad=True)
y = x**2 + 1
z = 2*y

def store(grad,parent):
    print(grad,parent)
    grads[parent] = grad.clone()

x.register_hook(lambda grad:store(grad,x))
y.register_hook(lambda grad:store(grad,y))

z.sum().backward()

例如,现在您可以y简单地使用访问 tensor 的 gradgrads[y]


推荐阅读