python-3.x - 为什么我的 GAN 在某个点之后没有产生更多好的图像?
问题描述
问题
我正在训练一个 gan 来生成人脸。在大约 500 个 epoch 内,它学会了生成如下图像:
嗯,目前这张图像看起来还不错,可以在图像中央看到一张人脸。
随后我又继续训练了 1000 个 epoch,但它没有再学到任何新东西,生成的图像仍然和上图属于同一水平。这是为什么?为什么我的 GAN 没有学会生成更好的图像?
模型代码
这是鉴别器的代码:
def define_discriminator(in_shape=(64, 64, 3)):
    """Build and compile the discriminator network.

    The network is four Conv2D -> BatchNormalization -> LeakyReLU ->
    MaxPooling2D -> Dropout stages (32, 64, 128 and 256 filters), followed
    by Flatten and a single sigmoid unit.

    Args:
        in_shape: Shape of one input image; passed as ``input_shape`` to
            the first convolution.

    Returns:
        A ``Sequential`` model compiled with binary cross-entropy loss,
        the Adam optimizer and an accuracy metric.
    """
    model = Sequential([
        Conv2D(32, (3, 3), padding='same', input_shape=in_shape),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(0.2),
        Conv2D(64, (3, 3), padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(0.3),
        Conv2D(128, (3, 3), padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(0.3),
        Conv2D(256, (3, 3), padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(0.4),
        Flatten(),
        Dense(1, activation='sigmoid'),
    ])
    # `learning_rate` is the supported keyword; `lr` is a deprecated alias
    # that is not accepted by every Keras release.
    opt = Adam(learning_rate=0.00002)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model
下面是生成器和 GAN 的代码:
def define_generator(in_shape=100):
    """Build the (uncompiled) generator network.

    Args:
        in_shape: Length of the latent input vector; passed as
            ``input_dim`` to the first Dense layer.

    Returns:
        An uncompiled ``Sequential`` model whose last layer is a
        3-filter transposed convolution with a sigmoid activation.
    """
    # Project the latent vector and reshape it to an 8x8 map, 256 channels.
    stem = [
        Dense(256 * 8 * 8, input_dim=in_shape),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        Reshape((8, 8, 256)),
    ]

    # Two stride-2 transposed-convolution stages (256 then 64 filters).
    upsampling = []
    for n_filters in (256, 64):
        upsampling.append(
            Conv2DTranspose(n_filters, (3, 3), strides=(2, 2), padding='same')
        )
        upsampling.append(BatchNormalization())
        upsampling.append(LeakyReLU(alpha=0.2))

    # Final stride-2 stage producing the 3-channel sigmoid output.
    output_layer = Conv2DTranspose(
        3, (4, 4), strides=(2, 2), padding='same', activation='sigmoid'
    )

    return Sequential(stem + upsampling + [output_layer])
def define_gan(d_model, g_model):
    """Stack the generator and discriminator into one compiled model.

    Side effect: sets ``d_model.trainable = False`` before the combined
    model is compiled, so that updates made through the combined model are
    meant to change only the generator's weights.

    Args:
        d_model: The discriminator model.
        g_model: The generator model.

    Returns:
        A ``Sequential`` model (generator followed by discriminator)
        compiled with binary cross-entropy loss and Adam.
    """
    d_model.trainable = False
    model = Sequential([
        g_model,
        d_model,
    ])
    # `learning_rate` is the supported keyword; `lr` is a deprecated alias
    # that is not accepted by every Keras release.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model
整个可重现的代码
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization
from tensorflow.keras.layers import Dropout, Flatten, Dense, Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D, Activation, Reshape, LeakyReLU
from tensorflow.keras.datasets import mnist
from tensorflow.keras.optimizers import Adam
from numpy import ones
from numpy import zeros
from numpy.random import rand
from numpy.random import randint
from numpy.random import randn
from numpy import vstack
from numpy import array
import os
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import img_to_array
from matplotlib import pyplot
def load_data(filepath):
    """Load face images from the first sub-folders of ``filepath``.

    Entries of ``filepath`` are visited in ``os.listdir`` order. The entry
    named ``'wiki.mat'`` is skipped, and loading stops once two folders
    have been read. Every image is resized to 64x64 and scaled to [0, 1].

    Args:
        filepath: Directory that contains the image sub-folders. It may be
            given with or without a trailing path separator.

    Returns:
        A numpy array holding all loaded images as float32 values.
    """
    image_array = []
    folders_loaded = 0
    for fold in os.listdir(filepath):
        if fold == 'wiki.mat':
            continue
        if folders_loaded > 1:
            break
        folder_path = os.path.join(filepath, fold)
        for img in os.listdir(folder_path):
            # Build the path with os.path.join (not string concatenation) so
            # it is valid even when `filepath` has no trailing separator.
            image = load_img(os.path.join(folder_path, img), target_size=(64, 64))
            img_array = img_to_array(image)
            img_array = img_array.astype('float32')
            img_array = img_array / 255.0
            image_array.append(img_array)
        folders_loaded += 1
    return array(image_array)
def generate_latent_points(n_samples, latent_dim=100):
    """Draw a batch of latent vectors from a standard normal distribution.

    Args:
        n_samples: Number of latent vectors to draw.
        latent_dim: Length of each latent vector.

    Returns:
        A numpy array of shape ``(n_samples, latent_dim)``.
    """
    total_values = n_samples * latent_dim
    return randn(total_values).reshape((n_samples, latent_dim))
def generate_real_samples(n_samples, dataset):
    """Pick random real samples (with replacement) and label them 1.

    Args:
        n_samples: Number of samples to draw.
        dataset: Array of samples, indexed along its first axis.

    Returns:
        A tuple ``(x, y)``: the selected rows of ``dataset`` and an
        ``(n_samples, 1)`` array of ones.
    """
    chosen = randint(0, dataset.shape[0], n_samples)
    labels = ones((n_samples, 1))
    return dataset[chosen], labels
def generate_fake_samples(g_model, n_samples):
    """Generate samples with the generator and label them 0.

    Args:
        g_model: Generator model; its ``predict`` method is called on the
            latent points.
        n_samples: Number of samples to generate.

    Returns:
        A tuple ``(x, y)``: the generator output and an ``(n_samples, 1)``
        array of zeros.
    """
    noise = generate_latent_points(n_samples)
    generated = g_model.predict(noise)
    return generated, zeros((n_samples, 1))
def save_plot(examples, epoch, n=10):
    """Save an n-by-n grid of images to ``generated_plot_eNNN.png``.

    Args:
        examples: Array of images indexed as ``(sample, row, column,
            channel)``; at least ``n * n`` samples are required.
        epoch: Zero-based epoch index; the file name uses ``epoch + 1``.
        n: Number of rows and of columns in the grid.
    """
    for i in range(n * n):
        pyplot.subplot(n, n, 1 + i)
        pyplot.axis('off')
        image = examples[i]
        if image.shape[-1] == 1:
            # Single-channel data: drop the channel axis, as imshow expects
            # a 2-D array for scalar images.
            pyplot.imshow(image[:, :, 0])
        else:
            # Multi-channel data: show every channel instead of only the
            # first one, which discarded the colour information.
            pyplot.imshow(image)
    filename = 'generated_plot_e%03d.png' % (epoch + 1)
    pyplot.savefig(filename)
    pyplot.close()
def summarize_performance(d_model, g_model, gan_model, dataset, epoch, n_samples=100):
    """Evaluate the models, print the metrics, and save a plot and the generator.

    Args:
        d_model: Compiled discriminator; ``evaluate`` is expected to return
            ``(loss, accuracy)``.
        g_model: Generator model.
        gan_model: Compiled combined model; ``evaluate`` is expected to
            return a single loss value.
        dataset: Array of real images.
        epoch: Zero-based epoch index used in the log line and file names.
        n_samples: Number of real and of fake samples to evaluate on.
            NOTE: the fake samples are passed to ``save_plot`` with its
            default grid size, which reads 100 images.
    """
    real_x, real_y = generate_real_samples(n_samples, dataset)
    _, d_real_acc = d_model.evaluate(real_x, real_y)

    fake_x, fake_y = generate_fake_samples(g_model, n_samples)
    _, d_fake_acc = d_model.evaluate(fake_x, fake_y)

    latent_points, y = generate_latent_points(n_samples), ones((n_samples, 1))
    gan_loss = gan_model.evaluate(latent_points, y)

    # acc_real uses %.3f: the previous %.3d formatted the accuracy as an
    # integer, so the fractional value was lost.
    print('Epoch %d, acc_real=%.3f, acc_fake=%.3f, gan_loss=%.3f' % (epoch, d_real_acc, d_fake_acc, gan_loss))

    save_plot(fake_x, epoch)

    # File name kept exactly as before (including its spelling) so that
    # previously saved models keep matching.
    filename = 'Genarator_Model % d' % (epoch + 1)
    g_model.save(filename)
def train(d_model, g_model, gan_model, dataset, epochs=200):
    """Alternate discriminator and generator updates for ``epochs`` epochs.

    Each step trains the discriminator on half a batch of real samples and
    half a batch of generated samples, then trains the generator through
    the combined model on a full batch of latent points labelled 1.
    Metrics from the last step are printed once per epoch, and
    ``summarize_performance`` runs on every even epoch.

    Args:
        d_model: Compiled discriminator.
        g_model: Generator model.
        gan_model: Compiled combined model.
        dataset: Array of real images.
        epochs: Number of passes to run.

    Raises:
        ValueError: If ``dataset`` holds fewer samples than one batch, in
            which case no training step could run.
    """
    batch_size = 64
    half_batch = int(batch_size / 2)
    batch_per_epoch = int(dataset.shape[0] / batch_size)
    if batch_per_epoch < 1:
        # Without this check the metrics printed below would be unbound.
        raise ValueError(
            'dataset has %d samples; at least %d are required'
            % (dataset.shape[0], batch_size)
        )

    for epoch in range(epochs):
        for i in range(batch_per_epoch):
            real_x, real_y = generate_real_samples(half_batch, dataset)
            _, d_real_acc = d_model.train_on_batch(real_x, real_y)

            fake_x, fake_y = generate_fake_samples(g_model, half_batch)
            _, d_fake_acc = d_model.train_on_batch(fake_x, fake_y)

            # Generated samples are labelled 1 here so the generator is
            # pushed toward outputs the discriminator scores as real.
            latent_points, y = generate_latent_points(batch_size), ones((batch_size, 1))
            gan_loss = gan_model.train_on_batch(latent_points, y)

        # acc_real uses %.3f: the previous %.3d formatted the accuracy as
        # an integer, so the fractional value was lost.
        print('Epoch %d, acc_real=%.3f, acc_fake=%.3f, gan_loss=%.3f' % (epoch, d_real_acc, d_fake_acc, gan_loss))

        if (epoch % 2) == 0:
            summarize_performance(d_model, g_model, gan_model, dataset, epoch)
# Script entry point: load the images, build the three models, then train.
# NOTE: `filepath` is intentionally not defined here because the dataset
# location differs per machine; assign it before running this section.
dataset = load_data(filepath)
# Build the discriminator first: define_gan() marks it non-trainable.
discriminator_model = define_discriminator()
generator_model = define_generator()
gan_model = define_gan(discriminator_model, generator_model)
# Runs with train()'s default number of epochs.
train(discriminator_model, generator_model, gan_model, dataset)
数据集
如果需要,可以在这里获取数据集。
解决方案
推荐阅读
- android - 我可以在片段活动之外替换片段布局吗
- ios - 如何使用 reloadItems 淡化项目(在:[indexPath])
- amazon-web-services - aws 上的异步 lambda 函数是按顺序调用的吗?
- node.js - 保存 mongoose 数组模式类型时忽略 null 或空值
- php - PHP Heroku 应用程序不允许我访问子目录
- typescript - 在返回承诺的函数上使用 .then 会抛出“这是不可调用的”
- javascript - 试图创建一个有分数和时间的排行榜。本地存储?
- java - FindBugs 插件 Eclipse 发现的错误
- angular - 如何在不从根 URL 开始的情况下使角度链接工作?
- angular - Angular SSR 无需在点击链接上重新加载页面