deep-learning - 防止更新卷积权重矩阵的特定元素
问题描述
我正在尝试将一个权重元素设置为 1,然后将其保持不变直到学习结束(防止它在下一个时期更新)。我知道我可以设置requires_grad = False
,但我只想对一个元素而不是所有元素进行此过程。
解决方案
您可以在您的上附加一个反向钩子,nn.Module
以便在反向传播期间您可以将感兴趣的元素覆盖到0
. 这确保了它的值永远不会改变,而不会阻止梯度反向传播到输入。
向后挂钩的新 API 是nn.Module.register_full_backward_hook
. 首先构造一个回调函数,用作层钩子:
def freeze_single(index):
def callback(module, grad_input, grad_output):
module.weight.grad.data[index] = 0
return callback
然后,我们可以将此钩子附加到任何nn.Module
. 例如,在这里我决定冻结[0, 1, 2, 1]
卷积层的组件:
>>> conv = nn.Conv2d(3, 1, 3)
>>> conv.weight.data[0, 1, 2, 1] = 1
>>> conv.register_full_backward_hook(freeze_single((0, 1, 2, 1)))
一切都设置正确,让我们试试:
>>> x = torch.rand(1, 3, 10, 10, requires_grad=True)
>>> conv(x).mean().backward()
在这里我们可以验证分量的梯度[0, 1, 2, 1]
确实等于0
:
>>> conv.weight.grad
tensor([[[[0.4954, 0.4776, 0.4639],
[0.5179, 0.4992, 0.4856],
[0.5271, 0.5219, 0.5124]],
[[0.5367, 0.5035, 0.5009],
[0.5703, 0.5390, 0.5207],
[0.5422, 0.0000, 0.5109]], # <-
[[0.4937, 0.5150, 0.5200],
[0.4817, 0.5070, 0.5241],
[0.5039, 0.5295, 0.5445]]]])
您可以随时使用以下方法拆卸/重新连接挂钩:
>>> hook = conv.register_full_backward_hook(freeze_single((0, 1, 2, 1)))
>>> hook.remove()
不要忘记如果你移除了钩子,当你更新你的权重时,那个组件的值会改变。如果您愿意,您将不得不将其重置为1
。否则,您可以实现第二个钩子——register_forward_pre_hook
这次是一个钩子——来处理它。
推荐阅读
- java - 活动和片段的不同onBackPressed
- css - Bootstrap 和 Ruby on Rails 的奇怪问题
- django - Django 在包含“注销”的 url 上中断
- c++ - 如何找到彼此为朋友的大小为 k(或更多)的人的序列?
- mysql - 当用户有很多评论时,为什么这个简单的 SELECT 需要 6 秒?
- python - 在 ET.iterpase() 中通过 Python 中的 XML.osm 更改属性值
- python-3.x - 在流数据上绘制滑动时间窗口
- python - Py:为什么 JSON 中的元素对每个动作都重复?
- php - 获取密码重置网址
- c# - Creating Exe/MSI for C# Windows Forms using Visual Studio 2017 Setup Project