python - 逐渐衰减损失函数的权重
问题描述
我不确定是问这个问题的正确地方,如果我需要删除帖子,请随时告诉我。
我是 pyTorch 的新手,目前正在使用 CycleGAN(pyTorch 实现)作为我项目的一部分,并且我了解 cycleGAN 的大部分实现。
我阅读了名为“CycleGAN with better Cycles”的论文,我正在尝试应用论文中提到的修改。修改之一是循环一致性权重衰减,我不知道如何应用。
optimizer_G.zero_grad()
# Identity loss
loss_id_A = criterion_identity(G_BA(real_A), real_A)
loss_id_B = criterion_identity(G_AB(real_B), real_B)
loss_identity = (loss_id_A + loss_id_B) / 2
# GAN loss
fake_B = G_AB(real_A)
loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
fake_A = G_BA(real_B)
loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
# Cycle consistency loss
recov_A = G_BA(fake_B)
loss_cycle_A = criterion_cycle(recov_A, real_A)
recov_B = G_AB(fake_A)
loss_cycle_B = criterion_cycle(recov_B, real_B)
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
# Total loss
loss_G = loss_GAN +
lambda_cyc * loss_cycle + #lambda_cyc is 10
lambda_id * loss_identity #lambda_id is 0.5 * lambda_cyc
loss_G.backward()
optimizer_G.step()
我的问题是如何逐渐减轻循环一致性损失的权重?
任何帮助实施此修改将不胜感激。
这来自论文:循环一致性损失有助于在早期阶段大量稳定训练,但在后期阶段成为现实图像的障碍。我们建议随着训练进度逐渐衰减循环一致性损失 λ 的权重。但是,我们仍然应该确保 λ 不会衰减到 0,这样生成器就不会变得不受约束并完全疯狂。
提前致谢。
解决方案
下面是一个可以使用的原型函数!
def loss (other params, decay params, initial_lambda, steps):
# compute loss
# compute cyclic loss
# function that computes lambda given the steps
cur_lambda = compute_lambda(step, decay_params, initial_lamdba)
final_loss = loss + cur_lambda*cyclic_loss
return final_loss
compute_lambda
以 50 步从 10 线性衰减到 1e-5 的函数
def compute_lambda(step, decay_params):
final_lambda = decay_params["final"]
initial_lambda = decay_params["initial"]
total_step = decay_params["total_step"]
start_step = decay_params["start_step"]
if (step < start_step+total_step and step>start_step):
return initial_lambda + (step-start_step)*(final_lambda-initial_lambda)/total_step
elif (step < start_step):
return initial_lambda
else:
return final_lambda
# Usage:
compute_lambda(i, {"final": 1e-5, "initial":10, "total_step":50, "start_step" : 50})
推荐阅读
- c - 在 C 中使用 MPI_Isend 发送多个非阻塞消息并使用 MPI_Recv 接收的问题
- excel - 从另一个工作簿复制工作表,其中工作表名称可以根据输入而更改
- powershell - 如何将结果保存到不同的变量
- c# - 测试用例在带有 https 的 docker windows 控制台中失败 vstest.console.exe
- shader - 一个 Texture2DArray 上的多个着色器资源视图
- node.js - 如何在实时服务器上自动运行 node js (express) 应用程序?
- c++ - 由于大输入的堆栈溢出,将通用递归转换为尾递归
- asp.net - 我怎样才能做跨页张贴?那更好?
- qt - 使用自己的 editorEvent() 处理的 QItemDelegate 捕获我想在 QTreeView 中接收的鼠标右键单击
- javascript - 我想在不和谐频道发送一个文件,机器人会读取这个文件