RuntimeError: TBackward version error after updating the SSH server. What can I do?

Problem description

I am currently working on my thesis with a PyTorch pix2pix model, which runs on an SSH server that was recently updated. Since the update, I have been getting the following error message when I try to run my code. Is there any way to downgrade my user environment back to the "old" version of TBackward?

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-18-abea6d9413b2> in <module>()
     31         if update_G:
     32             optimizer_G.zero_grad()
---> 33             losses[1].backward(retain_graph=True)
     34             optimizer_G.step()
     35 

/home/u1269991/.local/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222 
    223     def register_hook(self, hook):

/home/u1269991/.local/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors_, retain_graph, create_graph,
--> 132         allow_unreachable=True)  # allow_unreachable flag
    133 
    134 

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [41472, 1]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

The error occurs in the code around the losses. The loss code is as follows:

def forward_n_get_loss(real_img, cond_img, G, D, c=100):
    # forward: generate a fake image and build (image, condition) pairs
    fake_img = G(cond_img)
    real_pair = torch.cat([real_img, cond_img], 1)
    fake_pair = torch.cat([fake_img, cond_img], 1)
    prob_real, prob_fake = D(real_pair), D(fake_pair.detach())
    # loss_D
    loss_D_real = nn.BCELoss()(prob_real, torch.ones_like(prob_real))
    loss_D_fake = nn.BCELoss()(prob_fake, torch.zeros_like(prob_fake))
    loss_D = (loss_D_real + loss_D_fake) * 0.5
    # loss_G
    loss_G_fake = nn.BCELoss()(prob_fake, torch.ones_like(prob_fake))
    loss_G_FLIP_l1 = FLIP_l1_loss(fake_img, real_img)
    loss_G = loss_G_fake + c * loss_G_FLIP_l1
    return loss_D, loss_G, loss_D_real, loss_D_fake, loss_G_fake, loss_G_FLIP_l1

And here is where the backward calls are made:

# train!
import time
import numpy as np  # used by loss2npval below
start_epo = 0
n_epochs = 20
print_freq = 10
st = time.time()
loss2npval = lambda loss : np.mean(loss.cpu().data.numpy()).round(4)
loss_names = ["loss_D", "loss_G", "loss_D_real", "loss_D_fake",  \
              "loss_G_fake", "loss_G_FLIP_l1"]
update_DnG_together = True
G.train()
D.train()
for epo in range(n_epochs):
    for batch_No, (real_img_batch, cond_img_batch) in enumerate(train_loader):
        
        if use_cuda: real_img_batch, cond_img_batch = real_img_batch.cuda(), cond_img_batch.cuda() 
            
        # forward and get loss
        c = 100
        #c = (1-(epo+1)/n_epochs)*100+1 # decay c with epoch
        losses = forward_n_get_loss(real_img_batch, cond_img_batch, G, D, c=c )

        update_D, update_G = (True, True) if update_DnG_together \
            else ((epo + batch_No) % 2 != 0, (epo + batch_No) % 2 == 0)
        
        # update D
        if update_D:
            optimizer_D.zero_grad()
            losses[0].backward(retain_graph=True)
            optimizer_D.step()
            
        # update G
        if update_G:
            optimizer_G.zero_grad()
            losses[1].backward(retain_graph=True)
            optimizer_G.step()
        
    if epo==0 or epo%print_freq==(print_freq-1) or epo==n_epochs-1: 
        et = time.time() 
        loss_vals = map(loss2npval, losses)
        loss_info = dict(zip(loss_names,loss_vals))
        print("[{}] {}, time_cost: {:.2f} min" \
              .format( epo, loss_info, (et-st)/60 ))
        st = et # update st
        # save out
        this_epo_str = str(epo+start_epo).zfill(4) 
        path = '/content/drive/MyDrive/Thesis UvT'
        torch.save(G.state_dict(), f"{path}/PyTorch/G_{this_epo_str}")
        torch.save(D.state_dict(), f"{path}/PyTorch/D_{this_epo_str}")
        sample_img(G, train_loader, f"{path}/PyTorch/{this_epo_str}.png")

Tags: python, pytorch

Solution


The first issue (this is unrelated to the error, which I will get to later) is that you compute the generator loss with a BCE loss on prob_fake. However, prob_fake is detached from the generator: it comes from D(fake_pair.detach()), so gradients never backpropagate into the generator, and the generator is never trained. Instead, you should run the fake (generated) instances through the discriminator without detaching, i.e. D(fake_pair).
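Here is a minimal sketch of that change inside forward_n_get_loss (the name prob_fake_det is illustrative): keep a detached pass for the discriminator loss, and add a second, non-detached pass for the generator loss.

prob_real = D(real_pair)
prob_fake_det = D(fake_pair.detach())  # for loss_D: G is not trained through this pass
prob_fake = D(fake_pair)               # for loss_G: gradients flow back into G
# loss_D
loss_D_real = nn.BCELoss()(prob_real, torch.ones_like(prob_real))
loss_D_fake = nn.BCELoss()(prob_fake_det, torch.zeros_like(prob_fake_det))
# loss_G
loss_G_fake = nn.BCELoss()(prob_fake, torch.ones_like(prob_fake))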

The second issue is that you perform two backward-and-step passes, the first of them on the discriminator. optimizer_D.step() modifies the discriminator's weights in place, so by the time you backpropagate the generator loss loss_G, the weights stored in its retained graph are no longer at the version autograd recorded; that is exactly what the error message is reporting. In other words, loss_G_fake would have to be inferred with the updated discriminator weights.
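To see for yourself which operation is involved, you can follow the hint in the error message and enable anomaly detection before the training loop; autograd will then name the forward operation whose output was later modified in place.

torch.autograd.set_detect_anomaly(True)  # makes the traceback point at the offending forward op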

One way to prevent this side effect is to backpropagate both losses first, and only then perform the optimizer steps.
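A minimal sketch of that reordering, assuming forward_n_get_loss has been fixed as above so the two losses use separate discriminator passes:

optimizer_G.zero_grad()
losses[1].backward(retain_graph=True)  # loss_G; retained in case parts of the graph are shared
optimizer_D.zero_grad()                # drop the gradients loss_G spilled into D
losses[0].backward()                   # loss_D; no step has touched any weights yet
optimizer_G.step()
optimizer_D.step()

Zeroing the discriminator's gradients between the two backward calls matters: loss_G's path now runs through a non-detached discriminator, so without it the discriminator update would be contaminated by generator gradients.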

