首页 > 解决方案 > 大型 WGAN-GP 训练损失

问题描述

这是 WGAN-GP 的损失函数

gen_sample = model.generator(input_gen)
disc_real = model.discriminator(real_image, reuse=False)
disc_fake = model.discriminator(gen_sample, reuse=True)
disc_concat = tf.concat([disc_real, disc_fake], axis=0)
# Gradient penalty
alpha = tf.random_uniform(
    shape=[BATCH_SIZE, 1, 1, 1],
    minval=0.,
    maxval=1.)
differences = gen_sample - real_image
interpolates = real_image + (alpha * differences)
gradients = tf.gradients(model.discriminator(interpolates, reuse=True), [interpolates])[0]    # why [0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)

d_loss_real = tf.reduce_mean(disc_real)
d_loss_fake = tf.reduce_mean(disc_fake)

disc_loss = -(d_loss_real - d_loss_fake) + LAMBDA * gradient_penalty
gen_loss = - d_loss_fake

这是训练损失

发电机损失在振荡,价值如此之大。我的问题是:发电机损耗是正常还是异常?

标签: pythontensorflowdeep-learning

解决方案


需要注意的一件事是您的梯度惩罚计算是错误的。以下行:

slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))

实际上应该是:

slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1,2,3]))

您在第一个轴上减少,但渐变基于 alpha 值所示的图像,因此您必须在轴上减少[1,2,3]

您的代码中的另一个错误是生成器损失是:

gen_loss = d_loss_real - d_loss_fake

对于梯度计算,这没有区别,因为生成器的参数仅包含在 d_loss_fake 中。然而,对于发电机损失的价值,这使得世界上的一切都变得不同,这也是为什么会如此频繁的原因。

归根结底,您应该查看您关心的实际性能指标,以确定 GAN 的质量,例如初始分数或 Fréchet 初始距离 (FID),因为判别器和生成器的损失只是描述性的。


推荐阅读