首页 > 解决方案 > 具有额外自定义损失函数的 GAN

问题描述

下面展示了一个标准的 GAN 模型:

def define_gan(g_model, d_model, image_shape):
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # define the source image
    in_src = Input(shape=image_shape)
    # connect the source image to the generator input
    gen_out = g_model(in_src)
    # connect the source input and generator output to the discriminator input
    dis_out = d_model([in_src, gen_out])
    # src image as input, generated image and classification output
    model = Model(in_src, [dis_out, gen_out])
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
    return model

我想为模型添加额外的损失,如下所示:

def define_gan(g_model, d_model, image_shape, p_model):
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # define the source image
    in_src = Input(shape=image_shape)
    # connect the source image to the generator input
    gen_out = g_model(in_src)
    # connect the source input and generator output to the discriminator input
    dis_out = d_model([in_src, gen_out])
    # connect the p output to the generator output
    p_out = p_model(gen_out)
    # src image as input, generated image and classification output
    model = Model(in_src, [dis_out, gen_out, p_out])
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss=['binary_crossentropy', 'mae', 'mse'], optimizer=opt, loss_weights=[1,100,100])
    return model

可以看出,我添加了p模型,实际上并不是keras模型。它是一个函数,它获取 gan_out 生成的输出并将其与其他模型进行比较以测量额外损失。一小部分功能如下:

def define_p(image_shape):
    in_image = Input(shape=image_shape)
    u1 = np.reshape(in_image[:,:,0], (N))
    u2 = np.reshape(in_image[:,:,1], (N))
    u3 = np.reshape(in_image[:,:,2], (N))
    # compute two specific matrices (M, N) based on u1, u2, u3
    # ...
    out_image = M - N
    model = Model(in_image, out_image)
    return model

我面临的主要问题是 in_image 的尺寸为 (?,64,64,3),我无法对其进行整形,然后对每个批次分别对每个 u1、u2、u3 进行操作。非常感谢任何帮助!

标签: pythonkerasreshapelossgenerative-adversarial-network

解决方案


推荐阅读