pytorch - Torch:反向传播梯度而不更新变量
问题描述
在 pytorch 中,是否有一种有效的方法来反向传播梯度但不更新其相应的变量?在更新期间每次复制权重似乎太昂贵了。但是no_grad
&set_grad_enabled
不允许反向传播。
前任。以下似乎花费了太多时间,因为每次更新权重时都需要复制模型:
def __init__():
…
self.model = MyModel()
self.func1 = FuncModel1()
self.func2 = FuncModel2()
…
def trainstep(input):
f1 = self.func1(input)
f2 = self.func2(input)
…
# want to update weights in model & f1 with respect to loss1
loss1 = my_loss(model(f1), y1)
# don’t want to update weights in self.model with respect to loss2
# but want to update weights in f2 for loss2
copy_model = MyModel()
copy_model.load_state_dict(self.model.state_dict())
loss2 = my_loss(copy_model(f2), y2)
total_loss = loss1 + loss2
…
total_loss.backward()
optimizer.step()
解决方案
当loss.backward()
pytorch 通过整个计算图传播梯度时。
然而,backward()
函数本身并不更新任何权重,它只计算梯度。
更新是通过optimizer.step()
. 如果您想从更新中排除和的权重f1
,f2
您可以简单地使用
-initoptimizer
而不使用f1
and的参数f2
。
- 将学习率设置为f1
和f2
为零。
推荐阅读
- macos - 无法让 csshX 在 Mac OS Big Sur 上运行
- python - 在 Keras 序列模型中使用 tf.data.experimental.CsvDataset
- flutter - 加密 PDF 文档
- python - Selenium Python - 获取元素并单击每个元素
- c# - 未经授权时,ASP.NET Mvc 不会重定向到登录页面
- css - 圆圈中的不同图标显示为鸡蛋
- java - 使用 FBO 显示多通道着色器以通过 Android/OpenGLES2 进行高斯模糊时出现问题
- r - geom_bar 根据值使用 geom_point 颜色对数据进行分组
- vega-lite - 您可以在 Vega-lite 中动态设置配色方案和/或范围吗?
- wordpress - 恶意软件正在改变我
and <description> tags and I can't find the issue</h1> <div id="body"><p>When I go to my site, the title and description tags are fine. They are identified properly. I have viewed the source in chrome, firefox, and