python - Pytorch 对输入而不是输出求和雅可比
问题描述
假设我有一个Y
(直接或间接)从 tensor 计算的张量X
。
通常,当我申请时torch.autograd.grad(Y, X, grad_outputs=torch.ones_like(Y))
,我会得到一个与 形状相同的渐变蒙版X
。Y
这个掩码实际上是wrt元素梯度的加权和X
。
是否有可能得到一个形状相同的渐变蒙版Y
,其中每个元素mask[i][j]
都是Y[i][j]
wrt的渐变之和X
?
这等效于对J(Y,X)
的维度X
而不是 的维度求和雅可比行列式Y
。
>>> X = torch.eye(2)
>>> X.requires_grad_()
# X = [1 0]
# [0 1]
>>> Y = torch.sum(X*X, dim=0)
# Y = [1, 1]
>>> torch.autograd.grad(Y, X, grad_outputs=torch.ones_like(Y), retain_graph=True)
(tensor([[2., 0.],
[0., 2.]]),)
但相反,我想要:
# [2, 2]
因为torch.sum(torch.autograd.grad(Y[0],X)
等于2
也torch.sum(torch.autograd.grad(Y[1],X)
等于2
。
Y
计算wrt的雅可比行列式很容易X
,只需对 的维度求和即可X
。然而,这在记忆方面是不可行的,因为我使用的函数是具有大量输入和输出的神经网络。
单独计算每个梯度(正如我在评论中所做的那样)也是非常不可取的,因为这太慢了。
解决方案
如果您每晚运行 pytorch,https://github.com/pytorch/pytorch/issues/10223将部分实现,并且应该为大多数简单图形执行您想要的操作。您也可以尝试使用https://j-towns.github.io/2017/06/12/A-new-trick.html中描述的技巧。
编辑:它看起来像https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html#torch.autograd.functional.jvp为您实现了倒退技巧。所以你可以这样做:
from torch.autograd.functional import jvp
X = torch.eye(2)
X.requires_grad_()
def build_Y(x):
return torch.sum(x*x, dim=0)
print(jvp(build_Y, X, torch.ones(X.shape))[1])
推荐阅读
- c - 如何将带有分隔符的文件行中的单词存储到结构中
- typescript - 在 Typescript 声明中,如何让一个值依赖于另一个值的类型?
- python - Selenium - 'list' 对象没有属性 'screenshot_as_png'
- android - minSdkVersion > 21 是否需要 vectorDrawables.useSupportLibrary 标志
- html - 如何连接“<”和“>”字符以构建从 postgresql 通过 XML 和 XSL 到 HTML 的 img 标签
- spring - Spring Boot websocket RabbitMQ STOMP 中继代理在没有 TCP 连接的情况下从实例发送到客户端时无法发送消息
- powershell - 依次运行 2 个 Powershell 脚本
- r - 如何在 R 中使循环图动态/交互?
- python - Numpy 数组以索引为值的字典
- php - docker-compose:为什么我的 webapp 无法连接到 mysql?