python - PyTorch如何实现断开连接(连接和对应的梯度被屏蔽)?
问题描述
我尝试实现下图。正如你所看到的,神经元不是完全连接的,即,权重被屏蔽了,它们对应的梯度也被屏蔽了。
import torch
import numpy as np
x = torch.rand((3, 1))
# tensor([[ 0.8525],
# [ 0.1509],
# [ 0.9724]])
weights = torch.rand((2, 3), requires_grad=True)
# tensor([[ 0.3240, 0.0792, 0.6858],
# [ 0.5248, 0.4565, 0.3625]])
mask = torch.Tensor([[0,1,0],[1,0,1]])
# tensor([[ 0., 1., 0.],
# [ 1., 0., 1.]])
mask_weights = weights * mask
# tensor([[ 0.0000, 0.0792, 0.0000],
# [ 0.5248, 0.0000, 0.3625]])
y = torch.mm(mask_weights, x)
# tensor([[ 0.0120],
# [ 0.7999]])
这个问题最初发布在Pytorch 论坛上。注意上面的方法
mask_weights = 权重 * 掩码
不适合,因为相应的梯度不为 0 。
请问有什么优雅的方法吗?
先感谢您。
解决方案
其实上面的方法是对的。断开连接基本上阻止了相应连接上的前馈和反向传播。换句话说,权重和梯度被掩盖了。有问题的代码揭示了第一个,而这个答案揭示了后者。
mask_weights.register_hook(print)
z = torch.Tensor([[1], [1]])
# tensor([[ 1.],
# [ 1.]])
out = (y-z).mean()
# tensor(-0.6595)
out.backward()
# tensor([[ 0.1920, 0.1757, 0.0046],
# [ 0.1920, 0.1757, 0.0046]])
weights.grad
# tensor([[ 0.0000, 0.1757, 0.0000],
# [ 0.1920, 0.0000, 0.0046]])
如您所见,权重的梯度会被自动屏蔽。
推荐阅读
- python - 如何修复类别视图调用为在模板中不起作用的 age_rate 类别
- sql-server - 如何修复“实体类型需要定义主键。” ASP.Net Core 中的错误
- javascript - javascript 代码中显示的 this.value 未定义消息
- javascript - 我想记录我点击链接的时间,并从我点击其他页面上的其他链接时结束记录
- ejs - 如何修复 EJS 语法错误:“参数列表后缺少 )”
- mysql - 如何从表中选择常用标题?
- latex - 如何将太宽太长的表格(一张表格)放入两页
- asp.net-mvc-5 - 在 json 结果中返回二进制图像显示序列化错误
- oracle - 如何在 soa 12c bpel 中使变量为空
- json - Join on value from object in jsonb array