首页 > 解决方案 > 向keras网络添加类信息

问题描述

我试图弄清楚如何将我的数据集的标签信息与生成对抗网络一起使用。我正在尝试使用可以在此处找到的条件 GAN 的以下实现。我的数据集包含两个不同的图像域(真实对象和草图),具有共同的类信息(椅子、树、橙色等)。我选择了这种实现,它只将两个不同的域视为对应的不同“类”(训练样本X对应于真实图像,而目标样本y对应于草图图像)。

有没有办法修改我的代码并在我的整个架构中考虑类信息(椅子、树等)?实际上,我希望我的鉴别器预测我从生成器生成的图像是否属于特定类别,而不仅仅是它们是否真实。事实上,使用当前架构,系统学习在所有情况下创建类似的草图。

更新:鉴别器返回一个大小的张量,1x7x7然后在计算损失y_true之前y_pred通过一个展平层:

def discriminator_loss(y_true, y_pred):
     BATCH_SIZE=100
     return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.concatenate([K.ones_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:])),K.zeros_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:])) ]) ), axis=-1)

以及鉴别器对生成器的损失函数:

def discriminator_on_generator_loss(y_true,y_pred):
     BATCH_SIZE=100
     return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.ones_like(K.flatten(y_pred))), axis=-1)

此外,我对输出 1 层的鉴别器模型的修改:

model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
#model.add(Activation('sigmoid'))

现在鉴别器输出 1 层。如何相应修改上述损失函数?n_classes = 6对于预测真假配对的 + 一类,我应该使用 7 而不是 1吗?

标签: pythonkerasconv-neural-networkloss-functiongenerative-adversarial-network

解决方案


建议的解决方案

重用您共享的存储库中的代码,这里有一些建议的修改,以沿着您的生成器和鉴别器训练分类器(它们的架构和其他损失保持不变):

from keras import backend as K
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D

def lenet_classifier_model(nb_classes):
    # Snipped by Fabien Tanc - https://www.kaggle.com/ftence/keras-cnn-inspired-by-lenet-5
    # Replace with your favorite classifier...
    model = Sequential()
    model.add(Convolution2D(12, 5, 5, activation='relu', input_shape=in_shape, init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(25, 5, 5, activation='relu', init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(180, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(100, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes, activation='softmax', init='he_normal'))

def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
    inputs = Input((IN_CH, img_cols, img_rows))
    x_generator = generator(inputs)

    merged = merge([inputs, x_generator], mode='concat', concat_axis=1)
    discriminator.trainable = False
    x_discriminator = discriminator(merged)

    classifier.trainable = False
    x_classifier = classifier(x_generator)

    model = Model(input=inputs, output=[x_generator, x_discriminator, x_classifier])

    return model


def train(BATCH_SIZE):
    (X_train, Y_train, LABEL_train) = get_data('train')  # replace with your data here
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5
    discriminator = discriminator_model()
    generator = generator_model()
    classifier = lenet_classifier_model(6)
    generator.summary()
    discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
        generator, discriminator, classifier)
    d_optim = Adagrad(lr=0.005)
    g_optim = Adagrad(lr=0.005)
    generator.compile(loss='mse', optimizer="rmsprop")
    discriminator_and_classifier_on_generator.compile(
        loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
        optimizer="rmsprop")
    discriminator.trainable = True
    discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")
    classifier.trainable = True
    classifier.compile(loss="categorical_crossentropy", optimizer="rmsprop")

    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", int(X_train.shape[0] / BATCH_SIZE))
        for index in range(int(X_train.shape[0] / BATCH_SIZE)):
            image_batch = Y_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
            label_batch = LABEL_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]  # replace with your data here

            generated_images = generator.predict(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE])
            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5
                image = np.swapaxes(image, 0, 2)
                cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)
                # Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")

            # Training D:
            real_pairs = np.concatenate((X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], image_batch),
                                        axis=1)
            fake_pairs = np.concatenate(
                (X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], generated_images), axis=1)
            X = np.concatenate((real_pairs, fake_pairs))
            y = np.zeros((20, 1, 64, 64))  # [1] * BATCH_SIZE + [0] * BATCH_SIZE
            d_loss = discriminator.train_on_batch(X, y)
            print("batch %d d_loss : %f" % (index, d_loss))
            discriminator.trainable = False

            # Training C:
            c_loss = classifier.train_on_batch(image_batch, label_batch)
            print("batch %d c_loss : %f" % (index, c_loss))
            classifier.trainable = False

            # Train G:
            g_loss = discriminator_and_classifier_on_generator.train_on_batch(
                X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], 
                [image_batch, np.ones((10, 1, 64, 64)), label_batch])
            discriminator.trainable = True
            classifier.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss[1]))
            if index % 20 == 0:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)

理论细节

我认为对于条件 GAN 的工作原理以及鉴别器在此类方案中的作用存在一些误解。

鉴别器的作用

在 GAN 训练 [4] 的 min-max 游戏中,判别D器与生成器G(您真正关心的网络)进行对抗,以便在D的审查下,G在输出真实结果方面变得更好。

为此,D经过训练,可以将真实样本与来自 的样本区分开来G;whileG被训练为D通过在目标分布之后生成真实的结果/结果来愚弄。

注意:在条件 GAN 的情况下,即 GAN 将输入样本从一个域A(例如真实图片)映射到另一个域B (例如草图),D通常输入堆叠在一起的样本对,并且必须区分“真实”对(来自的输入样本A+ 来自 的相应目标样本B)和“假”对(来自 的输入样本A+ 来自 的相应输出G [1, 2]

训练条件生成器D(而不是简单地G单独训练,仅使用 L1/L2 损失,例如 DAE)提高了 的采样能力G,迫使它输出清晰、真实的结果,而不是试图平均分布。

即使鉴别器可以有多个子网络来覆盖其他任务(见下一段),D也应该至少保留一个子网络/输出来覆盖其主要任务:将真实样本与生成的样本区分开来。要求D回归进一步的语义信息(例如类)可能会干扰这个主要目的。

注意:D输出通常不是简单的标量/布尔值。通常有一个鉴别器(例如 PatchGAN [1, 2])返回一个概率矩阵,评估从其输入生成的补丁的真实性。


条件 GAN

传统的 GAN 以无监督的方式进行训练,以从作为输入的随机噪声向量生成真实数据(例如图像)。[4]

如前所述,条件 GAN 具有进一步的输入条件。沿着/而不是噪声向量,它们将来自域的样本作为输入,A并从域返回相应的样本BA可以是完全不同的模态,例如B = sketch imagewhile A = discrete label; B = volumetric dataA = RGB image等 [3]

这样的 GAN 也可以通过多个输入来调节,例如A = real image + discrete labelwhile B = sketch image。介绍这种方法的著名工作是InfoGAN [5]。它介绍了如何在多个连续或离散输入例如A = digit class + writing typeB = handwritten digit imageG


最大化 cGAN 的互信息

InfoGAN 鉴别器有 2 个头/子网络来覆盖其 2 个任务 [5]:

  • 一个负责D1人进行传统的真实/生成的区分——G必须最小化这个结果,即它必须愚弄D1以使其无法区分真实形式的生成数据;
  • 另一个头D2(也称为Q网络)试图回归输入A信息 -G必须最大化这个结果,即它必须输出“显示”请求的语义信息的数据(参见G条件输入与其输出之间的互信息最大化)。

例如,您可以在此处找到 Keras 实现:https ://github.com/eriklindernoren/Keras-GAN/tree/master/infogan 。

一些工作正在使用类似的方案来改进对 GAN 生成内容的控制,方法是使用提供的标签并最大化这些输入和G输出之间的互信息 [6, 7]。基本思想总是相同的:

  • 给定域的一些输入,训练G生成域的元素;BA
  • 训练D区分“真实”/“假”结果——G必须尽量减少这一点;
  • 训练Q(例如分类器;可以与 共享层)以估计来自样本D的原始A输入——必须最大化这一点)。BG

包起来

在您的情况下,您似乎有以下训练数据:

  • 真实图像Ia
  • 相应的草图图像Ib
  • 对应的类标签c

你想训练一个生成器G,以便给定图像Ia及其类标签c,它会输出正确的草图图像Ib'

总而言之,你有很多信息,你可以监督你在条件图像和条件标签上的训练......受上述方法 [1, 2, 5, 6, 7] 的启发,这里有一个使用所有这些信息来训练你的条件的可能方法G

网络G
  • 输入:Ia+c
  • 输出:Ib'
  • 架构:由您决定(例如 U-Net、ResNet、...)
  • 损失:Ib'&之间的L1/L2损失Ib-D损失,Q损失
网络D
  • 输入:Ia+ Ib(真对),Ia+ Ib'(假对)
  • 输出:“虚假”标量/矩阵
  • 架构:由你决定(例如 PatchGAN)
  • 损失:“虚假”估计的交叉熵
网络Q
  • 输入:(Ib真实样本,用于训练Q),Ib'(假样本,反向传播时G
  • 输出:(c'估计类)
  • 架构:由您决定(例如 LeNet、ResNet、VGG、...)
  • c损失:和之间的交叉熵c'
训练阶段:
  1. 训练D一批真实对Ia+Ib然后训练一批假对Ia+ Ib'
  2. 训练Q一批真实样本Ib
  3. 固定DQ重量;
  4. Train G,将其生成的输出传递Ib'给它们DQ通过它们进行反向传播。

注意:这是一个非常粗略的架构描述。我建议您阅读文献([1, 5, 6, 7] 作为一个好的开始)以获得更多细节,也许是更详尽的解决方案。


参考

  1. 伊索拉、菲利普等人。“使用条件对抗网络进行图像到图像的翻译。” arXiv 预印本(2017 年)。http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf
  2. 朱俊彦等人。“使用循环一致的对抗网络进行未配对的图像到图像转换。” arXiv 预印本 arXiv:1703.10593 (2017)。http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf
  3. 米尔扎、迈赫迪和西蒙·奥辛德罗。“有条件的生成对抗网络。” arXiv 预印本 arXiv:1411.1784 (2014)。https://arxiv.org/pdf/1411.1784
  4. Goodfellow,伊恩等人。“生成对抗网络。” 神经信息处理系统的进展。2014. http://papers.nips.cc/paper/5423-generation-adversarial-nets.pdf
  5. 陈,习,等。“Infogan:通过信息最大化生成对抗网络的可解释表示学习。” 神经信息处理系统的进展。2016. http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generation-adversarial-nets.pdf
  6. Lee、Minhyeok 和 Junhee Seok。“可控生成对抗网络”。arXiv 预印本 arXiv:1708.00598 (2017)。https://arxiv.org/pdf/1708.00598.pdf
  7. Odena、Augustus、Christopher Olah 和 Jonathon Shlens。“使用辅助分类器甘斯的条件图像合成。” arXiv 预印本 arXiv:1610.09585 (2016)。http://proceedings.mlr.press/v70/odena17a/odena17a.pdf

推荐阅读