首页 > 解决方案 > Pytorch 为 nn.Module 函数添加自定义反向传递

问题描述

我正在重新实现可逆残差网络架构。

class iResNetBlock(nn.Module):
    def __init__(self, input_size, hidden_size):
        self.bottleneck = nn.Sequential(
            LinearContraction(input_size, hidden_size),
            LinearContraction(hidden_size, input_size),
            nn.ReLU(),
        )
        
    def forward(self, x):
        return x + self.bottleneck(x)
    
    def inverse(self, y):
        x = y.clone()

        while not converged:
            # fixed point iteration
            x = y - self.bottleneck(x)
   
        return x

我想为inverse函数添加一个自定义的反向传递。由于它是定点迭代,因此可以利用隐函数定理来避免循环展开,而是通过求解线性系统来计算梯度。例如,这是在深度平衡模型架构中完成的。

    def inverse(self, y):
        with torch.no_grad():
            x = y.clone()
            while not converged:
                # fixed point iteration
                x = y - self.bottleneck(x)
   
            return x

    def custom_backward_inverse(self, grad_output):
        pass

如何为此功能注册我的自定义反向通行证?我希望,当我稍后定义一些损失时 r = loss(y, model.inverse(other_model(model(x))))r.backwards()正确地使用我的自定义渐变进行反向调用。

理想情况下,解决方案应该是torchscript兼容的。

标签: pytorchtorchscript

解决方案


推荐阅读