python-3.x - 在GAN中,是否需要编译生成器
问题描述
我一直在研究 GAN,让我摸不着头脑的是为什么我们必须编译生成器模型,即使我们编译组合的 GAN 模型,为什么还要单独编译生成器。
def create_generator():
generator = Sequential()
generator.add(Dense(256, input_dim=noise_dim))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(img_rows*img_cols*channels, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=optimizer)
return generator
def create_descriminator():
discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=img_rows*img_cols*channels))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
return discriminator
discriminator = create_descriminator()
generator = create_generator()
# Make the discriminator untrainable when we are training the generator. This doesn't effect the discriminator by itself
discriminator.trainable = False
# Link the two models to create the GAN
gan_input = Input(shape=(noise_dim,))
fake_image = generator(gan_input)
gan_output = discriminator(fake_image)
gan = Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer=optimizer)
在这段代码中,可以看到生成器、判别器和 gan(组合模型)这三个都被编译了。根据我的理解,我们应该只编译鉴别器(训练鉴别器)和 gan(组合模型,训练生成器),因为鉴别器权重在 GAN 训练期间被冻结,结果只有生成器得到训练。那么为什么要编译生成器
解决方案
在训练期间,thegenerator
和the 的discriminator
目标相反:discriminator
尝试将假图像与真实图像区分开来,而生成器尝试生成看起来足够真实的图像以欺骗鉴别器。
因为 GAN 由两个具有不同目标的网络组成,所以不能像常规神经网络那样进行训练。每次训练迭代分为两个阶段:
- 在第一阶段,我们训练判别器。从训练集中抽取一批真实图像,并用生成器生成的相同数量的假图像完成。假图像的标签设置为 0,真实图像的标签设置为 1,鉴别器在这个标记的批次上训练一步,使用二元交叉熵损失。重要的是,反向传播仅在此阶段优化鉴别器的权重。
- 在第二阶段,我们训练生成器。我们首先用它来产生另一批假图像,再一次用鉴别器来判断图像是假的还是真的。这次我们不在批次中添加真实图像,所有标签都设置为 1(真实):换句话说,我们希望生成器生成判别器(错误地)认为是真实的图像!至关重要的是,
discriminator
在frozen
此步骤中的权重,因此反向传播仅影响生成器的权重。
接下来,我们需要编译这些模型。generator
只会通过 训练,
gan model
所以我们根本不需要编译它。重要的是,discriminator
不应该在第二阶段训练,所以我们在 gan 模型non-trainable
之前
进行训练:compiling
推荐阅读
- web-crawler - Puppeteer Crawler 大型爬行
- java - 比较Java POJO的最简单方法是什么
- openssl - 如何修复 Windows 10 上的 libssl.lib not found 错误?
- mysql - 选择总和直到一定数量,然后根据条件更新某些字段
- javascript - 将数组中的任意值映射到另一个数组中存在的颜色
- php - 如何根据 html 日期选择器选择的值从数据库中选择记录
- infinispan - 错误/Infinispan 9.4.20.Final 到 10.1.8.Final 之间事务行为的变化
- c - 搜索不存在的项目时,链表出现分段错误
- javascript - Unicode编码问题
- c - 我在 C 编程中遇到错误。([Error] 预期标识符或 '(' 在 '{' 标记之前) 。有人可以帮我解决一下,让我知道为什么吗?