python - 大型 WGAN-GP 训练损失
问题描述
这是 WGAN-GP 的损失函数
gen_sample = model.generator(input_gen)
disc_real = model.discriminator(real_image, reuse=False)
disc_fake = model.discriminator(gen_sample, reuse=True)
disc_concat = tf.concat([disc_real, disc_fake], axis=0)
# Gradient penalty
alpha = tf.random_uniform(
shape=[BATCH_SIZE, 1, 1, 1],
minval=0.,
maxval=1.)
differences = gen_sample - real_image
interpolates = real_image + (alpha * differences)
gradients = tf.gradients(model.discriminator(interpolates, reuse=True), [interpolates])[0] # why [0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)
d_loss_real = tf.reduce_mean(disc_real)
d_loss_fake = tf.reduce_mean(disc_fake)
disc_loss = -(d_loss_real - d_loss_fake) + LAMBDA * gradient_penalty
gen_loss = - d_loss_fake
发电机损失在振荡,价值如此之大。我的问题是:发电机损耗是正常还是异常?
解决方案
需要注意的一件事是您的梯度惩罚计算是错误的。以下行:
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
实际上应该是:
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1,2,3]))
您在第一个轴上减少,但渐变基于 alpha 值所示的图像,因此您必须在轴上减少[1,2,3]
。
您的代码中的另一个错误是生成器损失是:
gen_loss = d_loss_real - d_loss_fake
对于梯度计算,这没有区别,因为生成器的参数仅包含在 d_loss_fake 中。然而,对于发电机损失的价值,这使得世界上的一切都变得不同,这也是为什么会如此频繁的原因。
归根结底,您应该查看您关心的实际性能指标,以确定 GAN 的质量,例如初始分数或 Fréchet 初始距离 (FID),因为判别器和生成器的损失只是描述性的。
推荐阅读
- php - 带有查询字符串的 Wordpress 分页错误
- macos - 在 macOS 上,如何打开隐藏或关闭屏幕的 gui .app?
- javascript - 我正在用 Twilio.Device Js 构建一个自动拨号器
- rust - 如果 Box 是协变的,这意味着什么
不是 Box 的子类型 哪里 B:A? - python - 如何在 % NANs 高于某个数字的情况下删除浮动功能?
- netsuite - NetSuite - 创建自定义 NetSuite CSV 导入字段
- excel - 如何在 vba 中编写以将多张工作表中的列捕获到一个字典对象中
- r - 即使代码显示不同的颜色场景,我的系统发育树仍然是彩虹调色板
- java - distinct() 方法不返回 HashSet 元素流的不同元素
- java - 为什么 ModelMapper 地图集合合并样式?