deep-learning - 分离 Pytorch 中关于部分损失的中间模块
问题描述
假设我有以下前向传递,导致两个单独的损失:
forward(self, input)
x = self.layer1(input)
y = self.layer2(x)
z = self.layer3(y)
return y, z
然后我们计算 loss1(y) 和 loss2(z)。然后我们可以loss = loss1 + loss2
使用单个优化器进行优化。
但是我有两个警告:(1)我希望仅针对 layer2 计算 d_loss1(没有 layer1),以及(2)我希望针对 layer3 和 layer1 计算 d_loss2 - 没有 layer2。本质上,我想单独训练网络的非连续部分,并单独损失。
我相信我可以通过在 layer2 的输入中引入停止梯度来解决(1),如下所示:
forward(self, input)
x = self.layer1(input)
y = self.layer2(x)
y_stop_gradient = self.layer2(Variable(x.data))
z = self.layer3(y)
return y_stop_gradient, z
但是我该如何解决(2)?换句话说,我希望 loss2 的梯度能够“跳过”layer2 ,同时保持layer2 对 loss1 的可训练性。
解决方案
在等待正确答案的同时,我找到了自己的答案,尽管它的效率非常低,我希望其他人能提出更好的解决方案。
我的解决方案如下所示:
import copy
forward(self, input)
x = self.layer1(input)
y = copy.deepcopy(self.layer2)(x) # create a full copy of the layer
y_stop_gradient = self.layer2(Variable(x.data))
z = self.layer3(y)
return y_stop_gradient, z
这个解决方案效率低下,因为(1)我认为深拷贝对于我正在尝试做的事情来说太过分了,而且成本太高,(2)仍然计算 layer2 相对于 z 的梯度,它们只是未使用。
推荐阅读
- sql - 使用 1 个 SQL 查询连接 3 个表,并找到为每个客户群产生最多收入的产品类别
- python - 使用 Graph、Vertex 和 Edge 数据结构在 Python 中创建地铁站网络
- kotlin - Micronaut 数据 JDBC 嵌套实体
- spring-boot - 谷歌浏览器当我点击电子邮件验证链接新标签加载几秒钟然后关闭为什么?
- r - 计算 R 中两个掷骰子总和的概率
- sql - 根据评估几列中的某些列是否为真来创建计算列?
- python - 在 Python 中查找井字游戏的结果
- sql - sql join 给出累积结果
- python - 如何将变量分配给要排序的数字
- javascript - 从父组件动态传递 v-date-picker 的默认日期并从子组件返回更改的日期