首页 > 解决方案 > 将 torch.backward() 用于 GAN 生成器时,为什么 Pytorch 中的判别器损失没有变化?

问题描述

我对 GAN 的理解是:

  1. 在训练您的生成器时,您需要先通过判别器进行反向传播,以便遵循链式规则。因此,我们.detach() 在进行生成器损失计算时不能使用 a。

  2. 更新判别器时,由于您的生成器权重更新不会影响判别器权重更新,我们可以.detach()从您的计算中生成生成器输出,我的理解告诉我,由于该生成器不再是计算图的一部分,因此我们不会在期间更新它不再支持后退。

因此,当我们更新您的鉴别器损失时:

disc_loss.backward(retain_graph=True).detach()在每个小批量中,由于函数调用,我们不必担心您的生成器会成为管道的一部分。

但是当我们在我们的生成器上工作时呢?是什么阻止了我们的模型根据生成器不断改变我们的鉴别器权重?毕竟这不是我们想要的吗?当被告知假样本是真实的时,鉴别器不应该学习。

为什么这样的模型首先会起作用

标签: pythonpytorchgenerative-adversarial-network

解决方案


backward不更新权重,它更新权重的梯度。更新权重是优化器的责任。实现 GAN 有不同的方法,但通常会有两个优化器,一个负责更新生成器的权重(并重置梯度),另一个负责更新生成器的权重(并重置梯度)鉴别器。在初始化时,每个优化器只提供它将更新的模型的权重。因此,当您调用优化器的 step 方法时,它只会更新这些权重。使用单独的优化器可以防止鉴别器权重被更新,同时最小化生成器的损失函数。


推荐阅读