首页 > 解决方案 > 分离 Pytorch 中关于部分损失的中间模块

问题描述

假设我有以下前向传递,导致两个单独的损失:

forward(self, input)
    x = self.layer1(input)

    y = self.layer2(x)

    z = self.layer3(y)

    return y, z

然后我们计算 loss1(y) 和 loss2(z)。然后我们可以loss = loss1 + loss2使用单个优化器进行优化。

但是我有两个警告:(1)我希望仅针对 layer2 计算 d_loss1(没有 layer1),以及(2)我希望针对 layer3 和 layer1 计算 d_loss2 - 没有 layer2。本质上,我想单独训练网络的非连续部分,并单独损失。

我相信我可以通过在 layer2 的输入中引入停止梯度来解决(1),如下所示:

forward(self, input)
    x = self.layer1(input)

    y = self.layer2(x)
    y_stop_gradient = self.layer2(Variable(x.data))

    z = self.layer3(y)

    return y_stop_gradient, z

但是我该如何解决(2)?换句话说,我希望 loss2 的梯度能够“跳过”layer2 ,同时保持layer2 对 loss1 的可训练性。

标签: deep-learningpytorchgradient-descent

解决方案


在等待正确答案的同时,我找到了自己的答案,尽管它的效率非常低,我希望其他人能提出更好的解决方案。

我的解决方案如下所示:

import copy
forward(self, input)
    x = self.layer1(input)

    y = copy.deepcopy(self.layer2)(x)  # create a full copy of the layer
    y_stop_gradient = self.layer2(Variable(x.data))

    z = self.layer3(y)

    return y_stop_gradient, z

这个解决方案效率低下,因为(1)我认为深拷贝对于我正在尝试做的事情来说太过分了,而且成本太高,(2)仍然计算 layer2 相对于 z 的梯度,它们只是未使用。


推荐阅读