首页 > 解决方案 > 防止更新卷积权重矩阵的特定元素

问题描述

我正在尝试将一个权重元素设置为 1,然后将其保持不变直到学习结束(防止它在下一个时期更新)。我知道我可以设置requires_grad = False,但我只想对一个元素而不是所有元素进行此过程。

标签: deep-learningneural-networkpytorchconvolution

解决方案


您可以在您的上附加一个反向钩子,nn.Module以便在反向传播期间您可以将感兴趣的元素覆盖到0. 这确保了它的值永远不会改变,而不会阻止梯度反向传播到输入。

向后挂钩的新 API 是nn.Module.register_full_backward_hook. 首先构造一个回调函数,用作层钩子:

def freeze_single(index):
    def callback(module, grad_input, grad_output):
        module.weight.grad.data[index] = 0
    return callback

然后,我们可以将此钩子附加到任何nn.Module. 例如,在这里我决定冻结[0, 1, 2, 1]卷积层的组件:

>>> conv = nn.Conv2d(3, 1, 3)
>>> conv.weight.data[0, 1, 2, 1] = 1

>>> conv.register_full_backward_hook(freeze_single((0, 1, 2, 1)))

一切都设置正确,让我们试试:

>>> x = torch.rand(1, 3, 10, 10, requires_grad=True)
>>> conv(x).mean().backward()

在这里我们可以验证分量的梯度[0, 1, 2, 1]确实等于0

>>> conv.weight.grad
tensor([[[[0.4954, 0.4776, 0.4639],
          [0.5179, 0.4992, 0.4856],
          [0.5271, 0.5219, 0.5124]],

         [[0.5367, 0.5035, 0.5009],
          [0.5703, 0.5390, 0.5207],
          [0.5422, 0.0000, 0.5109]], # <-

         [[0.4937, 0.5150, 0.5200],
          [0.4817, 0.5070, 0.5241],
          [0.5039, 0.5295, 0.5445]]]])

您可以随时使用以下方法拆卸/重新连接挂钩:

>>> hook = conv.register_full_backward_hook(freeze_single((0, 1, 2, 1)))
>>> hook.remove()

不要忘记如果你移除了钩子,当你更新你的权重时,那个组件的值会改变。如果您愿意,您将不得不将其重置为1。否则,您可以实现第二个钩子——register_forward_pre_hook这次是一个钩子——来处理它。


推荐阅读