python - 在 tf.keras 的 GAN 实现中正确设置 .trainable 变量
问题描述
我对 GAN 的实现中的.trainable
陈述感到困惑。tf.keras.model
鉴于以下代码被剪断(取自此 repo):
class GAN():
def __init__(self):
...
# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
# Build the generator
self.generator = self.build_generator()
# The generator takes noise as input and generates imgs
z = Input(shape=(self.latent_dim,))
img = self.generator(z)
# For the combined model we will only train the generator
self.discriminator.trainable = False
# The discriminator takes generated images as input and determines validity
validity = self.discriminator(img)
# The combined model (stacked generator and discriminator)
# Trains the generator to fool the discriminator
self.combined = Model(z, validity)
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
def build_generator(self):
...
return Model(noise, img)
def build_discriminator(self):
...
return Model(img, validity)
def train(self, epochs, batch_size=128, sample_interval=50):
# Load the dataset
(X_train, _), (_, _) = mnist.load_data()
# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
for epoch in range(epochs):
# ---------------------
# Train Discriminator
# ---------------------
# Select a random batch of images
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Generate a batch of new images
gen_imgs = self.generator.predict(noise)
# Train the discriminator
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# ---------------------
# Train Generator
# ---------------------
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Train the generator (to have the discriminator label samples as valid)
g_loss = self.combined.train_on_batch(noise, valid)
在模型定义期间,self.combined
鉴别器的权重被设置为self.discriminator.trainable = False
但从未重新打开。
尽管如此,在训练循环期间,鉴别器的权重会随着线条的变化而变化:
# Train the discriminator
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
并将在以下期间保持不变:
# Train the generator (to have the discriminator label samples as valid)
g_loss = self.combined.train_on_batch(noise, valid)
这是我没想到的。
当然,这是训练 GAN 的正确(迭代)方式,但我不明白为什么我们不必通过self.discriminator.trainable = True
才能对判别器进行一些训练。
如果有人对此有解释会很好,我想这是理解的关键点。
解决方案
当您对 github 存储库中的代码有疑问时,检查问题(打开和关闭)通常是一个好主意。 此问题解释了为什么将标志设置为False
. 它说,
由于
self.discriminator.trainable = False
是在编译器编译后设置的,所以不会影响判别器的训练。但是,由于它是在编译组合模型之前设置的,因此在训练组合模型时,鉴别器层将被冻结。
并且还谈到了冻结 keras 层。
推荐阅读
- vue.js - VueJS:同时使用 v-model 和 :value
- docker - 如何使用具有代码质量的 GitLab CI?
- jquery - 从 JQuery JSON 数组创建比较表
- html - 需要删除特定的 HTML 标签
- symfony - 在嵌入式表单集合中获取实体
- javascript - 为什么 POST 请求在 iOS Safari 中不起作用?
- validation - 编写验证代码时使用什么代替异常?
- angular - 我想滚动到页面顶部,因为我的列表包含一长串数据,我该怎么做?
- mysql - MySQL CASE WHEN numb IS NULL 忽略记录 WHERE numb IS NOT NULL
- excel - 如何使用 VBA 省略 arglist 中的最后一个变量