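python - RuntimeError: TBackward version error after the SSH server was updated, what should I do?
Problem description
I'm currently completing my thesis with a PyTorch pix2pix model, which runs on an SSH server that was recently updated. Since the update, I get the following error message when I try to run my code. Is there any way to downgrade my user back to the "old" version of TBackward?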
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-18-abea6d9413b2> in <module>()
31 if update_G:
32 optimizer_G.zero_grad()
---> 33 losses[1].backward(retain_graph=True)
34 optimizer_G.step()
35
/home/u1269991/.local/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
219 retain_graph=retain_graph,
220 create_graph=create_graph)
--> 221 torch.autograd.backward(self, gradient, retain_graph, create_graph)
222
223 def register_hook(self, hook):
/home/u1269991/.local/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
130 Variable._execution_engine.run_backward(
131 tensors, grad_tensors_, retain_graph, create_graph,
--> 132 allow_unreachable=True) # allow_unreachable flag
133
134
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [41472, 1]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
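As the hint at the end of the message suggests, anomaly detection can help locate the forward operation whose saved tensor was later modified. The `[41472, 1]` tensor reported as output 0 of `TBackward` is most likely the transposed weight of a linear layer inside `D`. The sketch below is a hypothetical minimal reproduction (a tiny `nn.Sequential` with plain SGD, not the pix2pix model): it runs one `backward(retain_graph=True)`, calls `optimizer.step()`, then backpropagates a second loss through the same graph. With `torch.autograd.set_detect_anomaly(True)` active, PyTorch should also emit a warning containing the traceback of the forward call that created the failing node.

```python
import torch
import torch.nn as nn


def reproduce_inplace_error():
    """Return the RuntimeError message, or None if backward succeeds."""
    torch.manual_seed(0)
    # Hypothetical toy model; the second Linear's input requires grad,
    # so autograd saves that layer's (transposed) weight for backward.
    net = nn.Sequential(nn.Linear(4, 3), nn.Tanh(), nn.Linear(3, 1))
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    x = torch.randn(8, 4)

    # Anomaly mode records forward tracebacks so the failing op can be traced.
    with torch.autograd.set_detect_anomaly(True):
        out = net(x).mean()
        loss_1 = out ** 2
        loss_2 = out * 3.0

        opt.zero_grad()
        loss_1.backward(retain_graph=True)  # keep the graph for loss_2
        opt.step()  # updates the weights in place, bumping their version

        opt.zero_grad()
        try:
            loss_2.backward()  # needs the weights as they were in forward
        except RuntimeError as err:
            return str(err)
    return None


print(reproduce_inplace_error())
```

The same pattern (backward with `retain_graph=True`, an optimizer step, then a second backward through the old graph) is what the training loop below does with `losses[0]` and `losses[1]`.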
The error occurs around the loss. The loss code is as follows:
def forward_n_get_loss(real_img, cond_img, G, D, c=100):
    fake_img = G(cond_img)
    real_pair = torch.cat([real_img, cond_img], 1)
    fake_pair = torch.cat([fake_img, cond_img], 1)
    prob_real, prob_fake = D(real_pair), D(fake_pair.detach())
    # loss_D
    loss_D_real = nn.BCELoss()(prob_real, torch.ones_like(prob_real))
    loss_D_fake = nn.BCELoss()(prob_fake, torch.zeros_like(prob_fake))
    loss_D = (loss_D_real + loss_D_fake) * 0.5
    # loss_G
    loss_G_fake = nn.BCELoss()(prob_fake, torch.ones_like(prob_fake))
    loss_G_FLIP_l1 = FLIP_l1_loss(fake_img, real_img)
    loss_G = loss_G_fake + c * loss_G_FLIP_l1
    return loss_D, loss_G, loss_D_real, loss_D_fake, loss_G_fake, loss_G_FLIP_l1
And here is where it calls backward():
# train!
import time
start_epo = 0
n_epochs = 20
print_freq = 10
st = time.time()
loss2npval = lambda loss: np.mean(loss.cpu().data.numpy()).round(4)
loss_names = ["loss_D", "loss_G", "loss_D_real", "loss_D_fake",
              "loss_G_fake", "loss_G_FLIP_l1"]
update_DnG_together = True
G.train()
D.train()
for epo in range(n_epochs):
    for batch_No, (real_img_batch, cond_img_batch) in enumerate(train_loader):
        if use_cuda:
            real_img_batch, cond_img_batch = real_img_batch.cuda(), cond_img_batch.cuda()
        # forward and get loss
        c = 100
        # c = (1-(epo+1)/n_epochs)*100+1  # decay c with epoch
        losses = forward_n_get_loss(real_img_batch, cond_img_batch, G, D, c=c)
        update_D, update_G = (True, True) if update_DnG_together \
            else ((epo+batch_No) % 2 != 0, (epo+batch_No) % 2 == 0)
        # update D
        if update_D:
            optimizer_D.zero_grad()
            losses[0].backward(retain_graph=True)
            optimizer_D.step()
        # update G
        if update_G:
            optimizer_G.zero_grad()
            losses[1].backward(retain_graph=True)
            optimizer_G.step()
    if epo == 0 or epo % print_freq == (print_freq-1) or epo == n_epochs-1:
        et = time.time()
        loss_vals = map(loss2npval, losses)
        loss_info = dict(zip(loss_names, loss_vals))
        print("[{}] {}, time_cost: {:.2f} min"
              .format(epo, loss_info, (et-st)/60))
        st = et  # update st
        # save out
        this_epo_str = str(epo+start_epo).zfill(4)
        path = '/content/drive/MyDrive/Thesis UvT'
        torch.save(G.state_dict(), f"{path}/PyTorch/G_{this_epo_str}")
        torch.save(D.state_dict(), f"{path}/PyTorch/D_{this_epo_str}")
        sample_img(G, train_loader, f"{path}/PyTorch/{this_epo_str}.png")
Solution
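The first problem (unrelated to the error, which I will get to later) is that you compute the generator loss with a BCE loss on `prob_fake`. However, `prob_fake` is detached from the generator, since it was computed from `D(fake_pair.detach())`. This means no gradient backpropagates to the generator, so the generator is never trained. Instead, you should pass the fake (generated) instances through the discriminator without detaching, i.e. `D(fake_pair)`.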
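The effect of `detach()` on the generator can be checked directly. In this sketch `G` and `D` are hypothetical single-layer stand-ins (not the pix2pix networks): a generator-style loss computed from `D(fake.detach())` leaves `G.weight.grad` empty, while the same loss computed from `D(fake)` populates it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
G = nn.Linear(2, 2)                               # hypothetical stand-in generator
D = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())  # hypothetical stand-in discriminator
bce = nn.BCELoss()
z = torch.randn(4, 2)


def generator_grad(detach_fake):
    """Backprop a generator-style BCE loss and return G.weight.grad."""
    G.zero_grad(set_to_none=True)
    D.zero_grad(set_to_none=True)
    fake = G(z)
    prob_fake = D(fake.detach() if detach_fake else fake)
    loss_G_fake = bce(prob_fake, torch.ones_like(prob_fake))
    loss_G_fake.backward()
    return G.weight.grad


print(generator_grad(detach_fake=True))   # None: no gradient reaches G
print(generator_grad(detach_fake=False))  # a 2x2 gradient tensor
```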
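The second problem is that you perform two backward passes and two optimizer steps, and the first one is on the discriminator. This means the discriminator's weights have already changed by the time you backpropagate the generator loss `loss_G`. Because `loss_G_fake` is computed from `prob_fake`, its graph still references discriminator weights that `optimizer_D.step()` has since modified in place, which is what the error reports: the saved tensor is at version 2, but backward expected version 1. You therefore need to infer `loss_G_fake` with the updated discriminator weights.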
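A minimal sketch of one training step that applies both fixes, under stated assumptions: the tiny `Linear` networks, the Adam learning rate, and `nn.L1Loss` (standing in for the asker's `FLIP_l1_loss`) are all hypothetical choices, not the original model. The discriminator is updated first with a detached fake; `D(fake_pair)` is then re-run after `optimizer_D.step()` without `detach()`, so the generator loss uses the current weights and no `retain_graph=True` is needed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins for the pix2pix networks; shapes are arbitrary.
G = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 3))
D = nn.Sequential(nn.Linear(6, 8), nn.Tanh(), nn.Linear(8, 1), nn.Sigmoid())
optimizer_G = torch.optim.Adam(G.parameters(), lr=2e-4)
optimizer_D = torch.optim.Adam(D.parameters(), lr=2e-4)
bce = nn.BCELoss()
l1 = nn.L1Loss()  # assumption: plain L1 in place of FLIP_l1_loss


def train_step(real_img, cond_img, c=100.0):
    fake_img = G(cond_img)
    real_pair = torch.cat([real_img, cond_img], 1)
    fake_pair = torch.cat([fake_img, cond_img], 1)

    # 1) Discriminator update: detach the fake so no gradient reaches G.
    prob_real = D(real_pair)
    prob_fake = D(fake_pair.detach())
    loss_D = 0.5 * (bce(prob_real, torch.ones_like(prob_real))
                    + bce(prob_fake, torch.zeros_like(prob_fake)))
    optimizer_D.zero_grad()
    loss_D.backward()
    optimizer_D.step()  # D's weights change in place here

    # 2) Generator update: run D again *after* its step and without detach,
    #    so the new graph uses the current weights and gradients reach G.
    prob_fake_for_G = D(fake_pair)
    loss_G = (bce(prob_fake_for_G, torch.ones_like(prob_fake_for_G))
              + c * l1(fake_img, real_img))
    optimizer_G.zero_grad()
    loss_G.backward()
    optimizer_G.step()
    return loss_D.item(), loss_G.item()


print(train_step(torch.rand(4, 3), torch.rand(4, 3)))
```

`loss_G.backward()` also accumulates gradients into `D`'s parameters here; in this sketch they are simply discarded by `optimizer_D.zero_grad()` at the next discriminator update. Freezing `D` with `requires_grad_(False)` during the generator step is a common alternative that avoids computing them.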
One way to prevent this side effect is to backpropagate the losses and only …