How should I load my own training data into this generative adversarial network?

Problem Description

I am an absolute beginner in Python. I recently loaded the MNIST handwritten-digit database into a generative adversarial network. The program runs fine, but I would like to know how to modify the code below so that I can load my own training data, a folder of JPGs, instead of the MNIST database. Is there a simple way to do this with this code?

I know I need to convert the images into the MNIST format, but beyond that I don't understand which lines I have to add and/or edit in order to load the folder.

Thanks for your help!

import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before importing keras
import numpy as np
import matplotlib
matplotlib.use("TkAgg")
from matplotlib import pyplot as plt
from tqdm import tqdm
from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Dense, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers

np.random.seed(10)
random_dim = 100

def load_minst_data():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5
    x_train = x_train.reshape(60000, 784)
    return (x_train, y_train, x_test, y_test)

def get_optimizer():
    return Adam(lr=0.0002, beta_1=0.5)

def get_generator(optimizer):
    generator = Sequential()
    generator.add(Dense(256, input_dim=random_dim,
                        kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(512))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(1024))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(784, activation='tanh'))
    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return generator

def get_discriminator(optimizer):
    discriminator = Sequential()
    discriminator.add(Dense(1024, input_dim=784,
                            kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(256))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(1, activation='sigmoid'))
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return discriminator

def get_gan_network(discriminator, random_dim, generator, optimizer):
    discriminator.trainable = False
    gan_input = Input(shape=(random_dim,))
    x = generator(gan_input)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input, outputs=gan_output)
    gan.compile(loss='binary_crossentropy', optimizer=optimizer)
    return gan

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10),
                          figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, random_dim])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan_generated_image_epoch_%d.png' % epoch)

def train(epochs=1, batch_size=128):
    x_train, y_train, x_test, y_test = load_minst_data()
    batch_count = x_train.shape[0] // batch_size
    adam = get_optimizer()
    generator = get_generator(adam)
    discriminator = get_discriminator(adam)
    gan = get_gan_network(discriminator, random_dim, generator, adam)
    for e in range(1, epochs + 1):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)
        for _ in tqdm(range(batch_count)):
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]
            generated_images = generator.predict(noise)
            X = np.concatenate([image_batch, generated_images])
            y_dis = np.zeros(2 * batch_size)
            y_dis[:batch_size] = 0.9
            discriminator.trainable = True
            discriminator.train_on_batch(X, y_dis)
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            y_gen = np.ones(batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y_gen)
        if e == 1 or e % 20 == 0:
            plot_generated_images(e, generator)

if __name__ == '__main__':
    train(400, 128)

Tags: python

Solution


MNIST is a dataset, not a format.

The following line is where the code's author loads the dataset:

x_train, y_train, x_test, y_test = load_minst_data()

inside:

def train(epochs=1, batch_size=128):
    x_train, y_train, x_test, y_test = load_minst_data()
    ...

This calls the function:

def load_minst_data():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = (x_train.astype(np.float32) - 127.5)/127.5
    x_train = x_train.reshape(60000, 784)
    return (x_train, y_train, x_test, y_test)

You can modify this function to load your own dataset instead.
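As a starting point, here is a minimal sketch of such a replacement function that reads a folder of JPGs, converts them to 28x28 grayscale, and flattens them the same way the MNIST loader does. It assumes Pillow is installed; the function name and the folder argument are illustrative, not part of the original code.

```python
import os
import numpy as np
from PIL import Image

def load_image_data(folder, image_size=28):
    """Load every JPG in `folder` as a flattened, normalized grayscale image."""
    images = []
    for name in sorted(os.listdir(folder)):
        if not name.lower().endswith(('.jpg', '.jpeg')):
            continue
        img = Image.open(os.path.join(folder, name)).convert('L')  # grayscale
        img = img.resize((image_size, image_size))
        images.append(np.asarray(img, dtype=np.float32))
    x = np.stack(images)
    x = (x - 127.5) / 127.5                               # scale pixels to [-1, 1]
    x = x.reshape(len(images), image_size * image_size)   # flatten, e.g. to 784
    return x
```

With 28x28 output, the returned array can stand in for `x_train` without touching the rest of the network.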

Also, since the code expects the four variables x_train, y_train, x_test, y_test, which are populated by:

(x_train, y_train), (x_test, y_test) = mnist.load_data()

I suggest using train_test_split. That function will help you split your data into those variables.
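For example, a split with scikit-learn's train_test_split could look like this (a sketch, assuming scikit-learn is installed; since this GAN never uses the labels, dummy zeros are fine for y):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Stand-in for a loaded, flattened JPG dataset of 1000 images
x = np.random.uniform(-1, 1, size=(1000, 784)).astype(np.float32)
y = np.zeros(len(x))  # dummy labels; the GAN does not use them

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=10)
```

The four returned arrays then slot directly into the tuple the train() function unpacks.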

The following line:

x_train = (x_train.astype(np.float32) - 127.5)/127.5

rescales the training samples to the range [-1, 1].
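To see the arithmetic concretely, 8-bit pixel values in [0, 255] are shifted and scaled into [-1, 1], which matches the generator's tanh output range:

```python
import numpy as np

pixels = np.array([0, 127.5, 255], dtype=np.float32)
scaled = (pixels - 127.5) / 127.5
print(scaled)  # [-1.  0.  1.]
```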

And in:

x_train = x_train.reshape(60000, 784)

the author flattens the 60000 samples into vectors of size 784 so they can be fed to the model.

P.S. The reshape to 784 works because MNIST images are originally 28x28 pixels.

You will also need to reshape your data to match, or change input_dim in:

discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))

and probably every other place where 784 appears.
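If your images are not 28x28, one way to keep all of those occurrences consistent is to derive the flattened dimension from the image size in a single variable (a sketch; image_size=64 is just an example value):

```python
image_size = 64
flat_dim = image_size * image_size  # replaces every hard-coded 784

# e.g. the discriminator's first layer and the generator's last layer would become:
# discriminator.add(Dense(1024, input_dim=flat_dim, ...))
# generator.add(Dense(flat_dim, activation='tanh'))
# and the reshapes would use (examples, image_size, image_size).
print(flat_dim)  # 4096
```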

