How to clear the gradients of a manually updated tensor in PyTorch

Problem description

What I want to do is run Projected Gradient Ascent with respect to the input X (an image, in my case) while disabling gradient computation for the network's parameters.

Here is my code:

import torch
import torch.nn as nn

def func(x: torch.Tensor, y: torch.Tensor,
         network: nn.Module, loss_func: nn.Module, eta: float = .1, steps: int = 10):

    network.requires_grad_(False)  # don't calculate grads for the parameters of the network
    x_copy = x.clone().requires_grad_(True)  # copy of the main input
    for i in range(steps):
        pred = network(x_copy)        # predictions of the network
        loss = loss_func(pred, y)     # calculation of the loss
        x_copy.retain_grad()          # retain gradients of the input

        loss.backward()

        x_copy = x_copy + eta * x_copy.grad.sign()  # gradient ascent step on the input

    return x_copy

Normally, to clear the gradients, we call optimiser.zero_grad() after each loss.backward(). But how do I do that when I am updating the tensor manually, without an optimizer? Right now I get this error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time

Tags: pytorch

Solution
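In the posted loop, x_copy stops being a leaf tensor after the first update, the update step itself is recorded into the autograd graph, and the input's .grad is never cleared between iterations. A minimal sketch of one common pattern for this kind of manual update, assuming the goal is only to update the input: perform the ascent step inside torch.no_grad() so x_copy stays a leaf and each iteration builds a fresh graph, then clear its gradient by hand with x_copy.grad.zero_() (or x_copy.grad = None), which plays the role of optimiser.zero_grad().

import torch
import torch.nn as nn

def func(x: torch.Tensor, y: torch.Tensor,
         network: nn.Module, loss_func: nn.Module, eta: float = .1, steps: int = 10):

    network.requires_grad_(False)                       # parameters stay frozen
    x_copy = x.clone().detach().requires_grad_(True)    # leaf copy of the input
    for i in range(steps):
        pred = network(x_copy)                          # forward pass builds a fresh graph
        loss = loss_func(pred, y)
        loss.backward()                                 # fills only x_copy.grad

        with torch.no_grad():                           # ascent step is not recorded
            x_copy += eta * x_copy.grad.sign()

        x_copy.grad.zero_()                             # manual equivalent of optimiser.zero_grad()

    return x_copy

Setting x_copy.grad = None instead of calling zero_() also works and drops the old gradient buffer entirely. Another option is to rebuild the leaf each step, e.g. x_copy = (x_copy + eta * x_copy.grad.sign()).detach().requires_grad_(True), which discards both the old graph and the old gradient in one go, so the retain_graph error does not come up and retain_grad() is no longer needed.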

