python - 如何在pytorch中返回中间梯度(对于非叶节点)?
问题描述
我的问题是关于 pytorch 的语法register_hook
。
x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y
x.register_hook(print)
y.register_hook(print)
z.backward()
输出:
tensor([2.])
tensor([4.])
z
此代码段仅分别打印wrtx
和的梯度y
。
现在我(很可能是微不足道的)问题是如何返回中间渐变(而不仅仅是打印)?
更新:
看来调用retain_grad()
解决了叶节点的问题。前任。y.retain_grad()
.
但是,retain_grad
似乎没有解决非叶节点的问题。有什么建议么?
解决方案
我认为您可以使用这些挂钩将渐变存储在全局变量中:
grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y
x.register_hook(lambda d:grads.append(d))
y.register_hook(lambda d:grads.append(d))
z.backward()
但是您很可能还需要记住计算这些梯度的相应张量。在这种情况下,我们使用 adict
而不是稍微扩展一下list
:
grads = {}
x = torch.tensor([1.,2.], requires_grad=True)
y = x**2 + 1
z = 2*y
def store(grad,parent):
print(grad,parent)
grads[parent] = grad.clone()
x.register_hook(lambda grad:store(grad,x))
y.register_hook(lambda grad:store(grad,y))
z.sum().backward()
例如,现在您可以y
简单地使用访问 tensor 的 gradgrads[y]
推荐阅读
- can-bus - 发送具有相同帧ID的数据帧时如何避免CAN总线上的冲突?
- javascript - 使用组件 onClick
- java - 无法解析项目 org.openjdk.jmh:jmh-core:jar:1.21 的依赖项
- r - 如何将数据框从求和扩展到单个观察
- python - 在 Python 中查找二维列表中的项目
- javascript - 如何使用 lodash 对多个进行分组
- python - 在 VSCode 中进入导入的标准模块
- gnuplot - 绘制多条曲线
- android - Dagger 2:在具有多个模块依赖项的组件中使用 Component.Builder
- java - 在后台下载 JSON 时出错