首页 > 解决方案 > PyTorch如何实现断开连接(连接和对应的梯度被屏蔽)?

问题描述

我尝试实现下图。正如你所看到的,神经元不是完全连接的,即,权重被屏蔽了,它们对应的梯度也被屏蔽了。

在此处输入图像描述

import torch
import numpy as np

x = torch.rand((3, 1))
# tensor([[ 0.8525],
#         [ 0.1509],
#         [ 0.9724]])

weights = torch.rand((2, 3), requires_grad=True)
# tensor([[ 0.3240,  0.0792,  0.6858],
#         [ 0.5248,  0.4565,  0.3625]])

mask = torch.Tensor([[0,1,0],[1,0,1]])
# tensor([[ 0.,  1.,  0.],
#         [ 1.,  0.,  1.]])

mask_weights = weights * mask
# tensor([[ 0.0000,  0.0792,  0.0000],
#         [ 0.5248,  0.0000,  0.3625]])

y = torch.mm(mask_weights, x)
# tensor([[ 0.0120],
#         [ 0.7999]])

这个问题最初发布在Pytorch 论坛上。注意上面的方法

mask_weights = 权重 * 掩码

不适合,因为相应的梯度不为 0

请问有什么优雅的方法吗?

先感谢您。

标签: pythonneural-networkpytorch

解决方案


其实上面的方法是对的。断开连接基本上阻止了相应连接上的前馈和反向传播。换句话说,权重和梯度被掩盖了。有问题的代码揭示了第一个,而这个答案揭示了后者。

mask_weights.register_hook(print)

z = torch.Tensor([[1], [1]])
# tensor([[ 1.],
#         [ 1.]])

out = (y-z).mean()
# tensor(-0.6595)

out.backward()
# tensor([[ 0.1920,  0.1757,  0.0046],
#         [ 0.1920,  0.1757,  0.0046]])

weights.grad
# tensor([[ 0.0000,  0.1757,  0.0000],
#         [ 0.1920,  0.0000,  0.0046]])

如您所见,权重的梯度会被自动屏蔽。


推荐阅读