首页 > 解决方案 > 加载 tf.keras.Model 子类后,调用的是超类中的 train_step

问题描述

我正在使用 keras 实现 GAN,其中我覆盖了 train_step 函数来自定义我自己的训练循环。当我创建模型并对其进行拟合时,一切正常。但是,当我加载之前保存的模型并尝试对其进行拟合时,调用的 train_step 函数属于 keras.Model 超类,这会导致抛出以下错误:

    ValueError: No gradients provided for any variable: ['sequential/conv2d/kernel:0', 'sequential_1/conv2d_1/kernel:0', 'sequential_1/batch_normalization/gamma:0', 'sequential_1/batch_normalization/beta:0', 'sequential_2/conv2d_2/kernel:0', 'sequential_2/batch_normalization_1/gamma:0', 'sequential_2/batch_normalization_1/beta:0', 'sequential_3/conv2d_3/kernel:0', 'sequential_3/batch_normalization_2/gamma:0', 'sequential_3/batch_normalization_2/beta:0', 'sequential_4/conv2d_4/kernel:0', 'sequential_4/batch_normalization_3/gamma:0', 'sequential_4/batch_normalization_3/beta:0', 'sequential_5/conv2d_5/kernel:0', 'sequential_5/batch_normalization_4/gamma:0', 'sequential_5/batch_normalization_4/beta:0', 'sequential_6/conv2d_6/kernel:0', 'sequential_6/batch_normalization_5/gamma:0', 'sequential_6/batch_normalization_5/beta:0', 'sequential_7/conv2d_7/kernel:0', 'sequential_7/batch_normalization_6/gamma:0', 'sequential_7/batch_normalization_6/beta:0', 'sequential_8/conv2d_transpose/kernel:0', 'sequential_8/batch_normalization_7/gamma:0', 'sequential_8/batch_normalization_7/beta:0', 'sequential_9/conv2d_transpose_1/kernel:0', 'sequential_9/batch_normalization_8/gamma:0', 'sequential_9/batch_normalization_8/beta:0', 'sequential_10/conv2d_transpose_2/kernel:0', 'sequential_10/batch_normalization_9/gamma:0', 'sequential_10/batch_normalization_9/beta:0', 'sequential_11/conv2d_transpose_3/kernel:0', 'sequential_11/batch_normalization_10/gamma:0', 'sequential_11/batch_normalization_10/beta:0', 'sequential_12/conv2d_transpose_4/kernel:0', 'sequential_12/batch_normalization_11/gamma:0', 'sequential_12/batch_normalization_11/beta:0', 'sequential_13/conv2d_transpose_5/kernel:0', 'sequential_13/batch_normalization_12/gamma:0', 'sequential_13/batch_normalization_12/beta:0', 'sequential_14/conv2d_transpose_6/kernel:0', 'sequential_14/batch_normalization_13/gamma:0', 'sequential_14/batch_normalization_13/beta:0', 'conv2d_transpose_7/kernel:0', 'conv2d_transpose_7/bias:0', 
'sequential_15/conv2d_8/kernel:0', 'sequential_16/conv2d_9/kernel:0', 'sequential_16/batch_normalization_14/gamma:0', 'sequential_16/batch_normalization_14/beta:0', 'sequential_17/conv2d_10/kernel:0', 'sequential_17/batch_normalization_15/gamma:0', 'sequential_17/batch_normalization_15/beta:0', 'conv2d_11/kernel:0', 'batch_normalization_16/gamma:0', 'batch_normalization_16/beta:0', 'conv2d_12/kernel:0', 'conv2d_12/bias:0'].

训练脚本如下所示:

if __name__ == '__main__':
    # Training entry point: build (or restore) the GAN, then run one short
    # fit() pass that saves the model and previews generator output at the
    # end of every epoch.
    ds_train = create_dataset('flic', test=False, batch_size=32)
    ds_test = create_dataset('flic', test=True, batch_size=32)

    model_path = './model/gan'

    restoring = os.path.exists(model_path)
    if restoring:
        print('Model found in disk, restoring...')
    else:
        print('Model not found in disk, creating new one...')
        os.makedirs(model_path)

    # Always construct the model through the PatchGAN class. The object that
    # load_model() returned did not use PatchGAN.train_step, so fit() fell
    # back to the base keras.Model.train_step and failed with
    # "No gradients provided for any variable".
    generator = build_generator(img_width=256, img_height=256, output_channels=2)
    discriminator = build_discriminator(img_width=256, img_height=256, output_channels=2)
    model = PatchGAN(generator=generator, discriminator=discriminator)

    if restoring:
        # Restore only the variables into the freshly built PatchGAN.
        # NOTE(review): relies on Model.load_weights accepting the directory
        # written by save_model() (SavedModel format) -- confirm for the
        # installed TensorFlow version.
        model.load_weights(model_path)

    model.compile()

    model.fit(
        ds_train,
        epochs=1,
        steps_per_epoch=1,
        callbacks=[
            TensorBoard(),
            LambdaCallback(on_epoch_end=lambda epoch, logs: save_model(model, model_path)),
            LambdaCallback(on_epoch_end=lambda epoch, logs: preview_output(model.generator, ds_test))
        ]
    )

GAN 代码如下所示:

    class PatchGAN(tf.keras.Model):
        """GAN wrapper that trains a generator and a discriminator together.

        The model owns one optimizer per sub-network and overrides
        ``train_step`` so that ``fit()`` updates both networks on every batch.
        Calling the model runs the generator only.
        """

        def __init__(self,
                     generator,
                     discriminator,
                     lamb=100,
                     loss_function=None,
                     generator_optimizer=None,
                     discriminator_optimizer=None,
                     *args,
                     **kwargs):
            """Store the sub-networks, the loss and the two optimizers.

            Args:
                generator: model mapping an input image to a generated image.
                discriminator: model called on ``[input_image, image]`` pairs.
                lamb: weight of the L1 term in the generator loss.
                loss_function: adversarial loss; defaults to a new
                    ``BinaryCrossentropy(from_logits=True)``.
                generator_optimizer: defaults to a new
                    ``Adam(learning_rate=0.0003)``.
                discriminator_optimizer: defaults to a new
                    ``Adam(learning_rate=0.0003)``.
                *args: forwarded to ``tf.keras.Model.__init__``.
                **kwargs: forwarded to ``tf.keras.Model.__init__``.
            """
            super().__init__(*args, **kwargs)

            # The defaults are created here rather than in the signature:
            # signature defaults are evaluated once, so every PatchGAN
            # instance would otherwise share the same loss and optimizer
            # objects.
            if loss_function is None:
                loss_function = BinaryCrossentropy(from_logits=True)
            if generator_optimizer is None:
                generator_optimizer = Adam(learning_rate=0.0003)
            if discriminator_optimizer is None:
                discriminator_optimizer = Adam(learning_rate=0.0003)

            self.lamb = lamb
            self.generator = generator
            self.discriminator = discriminator
            self.loss_function = loss_function
            self.generator_optimizer = generator_optimizer
            self.discriminator_optimizer = discriminator_optimizer
            self._set_inputs(generator.inputs)

        def call(self, inputs, training=None, mask=None):
            """Run the generator on ``inputs`` and return its output.

            ``training`` is forwarded to the generator; ``mask`` is unused.
            """
            return self.generator(inputs, training=training)

        def train_step(self, data):
            """Run one optimization step for both networks.

            Args:
                data: a pair ``(input_image, target)``.

            Returns:
                dict with the keys ``d_loss``, ``g_loss``, ``gen_gan_loss``
                and ``gen_l1_loss``.
            """
            input_image, target = data

            # Two tapes record the same forward pass so that each network's
            # gradients can be taken from its own loss.
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                gen_output = self.generator(input_image, training=True)
                disc_real_output = self.discriminator([input_image, target], training=True)
                disc_generated_output = self.discriminator([input_image, gen_output], training=True)

                # Generator loss: discriminator output on the generated image
                # scored against all-ones labels, plus the weighted mean
                # absolute error against the target.
                gan_loss = self.loss_function(tf.ones_like(disc_generated_output), disc_generated_output)
                l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
                total_gen_loss = gan_loss + (self.lamb * l1_loss)

                # Discriminator loss: real pairs scored against ones,
                # generated pairs scored against zeros.
                real_loss = self.loss_function(tf.ones_like(disc_real_output), disc_real_output)
                generated_loss = self.loss_function(tf.zeros_like(disc_generated_output), disc_generated_output)
                total_disc_loss = real_loss + generated_loss

            generator_gradients = gen_tape.gradient(total_gen_loss, self.generator.trainable_variables)
            discriminator_gradients = disc_tape.gradient(total_disc_loss, self.discriminator.trainable_variables)

            self.generator_optimizer.apply_gradients(
                zip(generator_gradients, self.generator.trainable_variables))
            self.discriminator_optimizer.apply_gradients(
                zip(discriminator_gradients, self.discriminator.trainable_variables))

            return {
                "d_loss": total_disc_loss,
                "g_loss": total_gen_loss,
                "gen_gan_loss": gan_loss,
                "gen_l1_loss": l1_loss
            }


def _downsample(filters, size, apply_batchnorm=True):
    """Build a stride-2 convolution block as a ``tf.keras.Sequential``.

    The block is Conv2D (no bias, 'same' padding), an optional
    BatchNormalization, and a LeakyReLU.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        apply_batchnorm: whether to insert BatchNormalization after the conv.

    Returns:
        The assembled ``tf.keras.Sequential`` block.
    """
    weight_init = tf.random_normal_initializer(0., 0.02)
    block = tf.keras.Sequential()

    stages = [
        tf.keras.layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=weight_init,
            use_bias=False,
        ),
    ]
    if apply_batchnorm:
        stages.append(tf.keras.layers.BatchNormalization())
    stages.append(tf.keras.layers.LeakyReLU())

    for stage in stages:
        block.add(stage)
    return block


def _upsample(filters, size, apply_dropout=False):
    """Build a stride-2 transposed-convolution block as a ``tf.keras.Sequential``.

    The block is Conv2DTranspose (no bias, 'same' padding),
    BatchNormalization, an optional Dropout(0.5), and a ReLU.

    Args:
        filters: number of transposed-convolution filters.
        size: kernel size.
        apply_dropout: whether to insert Dropout(0.5) after the batch norm.

    Returns:
        The assembled ``tf.keras.Sequential`` block.
    """
    weight_init = tf.random_normal_initializer(0., 0.02)
    block = tf.keras.Sequential()

    stages = [
        tf.keras.layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=weight_init,
            use_bias=False,
        ),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        stages.append(tf.keras.layers.Dropout(0.5))
    stages.append(tf.keras.layers.ReLU())

    for stage in stages:
        block.add(stage)
    return block


def build_generator(img_width, img_height, output_channels):
    """Assemble the generator: an encoder/decoder with skip connections.

    The input is a single-channel image of shape
    ``[img_width, img_height, 1]``. Eight stride-2 ``_downsample`` blocks
    are followed by seven stride-2 ``_upsample`` blocks; each upsampled
    tensor is concatenated with the matching encoder output. A final
    stride-2 transposed convolution with ``tanh`` activation produces
    ``output_channels`` channels.

    Args:
        img_width: first spatial dimension of the input.
        img_height: second spatial dimension of the input.
        output_channels: number of channels in the generated image.

    Returns:
        A functional ``tf.keras.Model`` from the input image to the output.
    """
    image_in = tf.keras.layers.Input(shape=[img_width, img_height, 1])

    # Encoder: only the first block skips batch normalization.
    encoder = [_downsample(64, 4, apply_batchnorm=False)]
    encoder.extend(
        _downsample(width, 4) for width in (128, 256, 512, 512, 512, 512, 512)
    )

    # Decoder: the first three blocks use dropout.
    decoder = [_upsample(512, 4, apply_dropout=True) for _ in range(3)]
    decoder.extend(_upsample(width, 4) for width in (512, 256, 128, 64))

    weight_init = tf.random_normal_initializer(0., 0.02)
    output_layer = tf.keras.layers.Conv2DTranspose(
        output_channels,
        4,
        strides=2,
        padding='same',
        kernel_initializer=weight_init,
        activation='tanh',
    )

    # Run the encoder, remembering every intermediate output for the skips.
    features = image_in
    encoder_outputs = []
    for block in encoder:
        features = block(features)
        encoder_outputs.append(features)

    # The deepest encoder output feeds the decoder directly, so it is left
    # out of the skip list; the remaining ones are used deepest-first.
    for block, skip in zip(decoder, reversed(encoder_outputs[:-1])):
        upsampled = block(features)
        features = tf.keras.layers.Concatenate()([upsampled, skip])

    generated = output_layer(features)
    return tf.keras.Model(inputs=image_in, outputs=generated)


def build_discriminator(img_width, img_height, output_channels):
    """Assemble the discriminator as a two-input functional model.

    The inputs are ``input_image`` (one channel) and ``target_image``
    (``output_channels`` channels), concatenated along the channel axis.
    The output is a single-channel map produced by a final convolution
    with no activation.

    Args:
        img_width: first spatial dimension of both inputs.
        img_height: second spatial dimension of both inputs.
        output_channels: number of channels of the target/generated image.

    Returns:
        A ``tf.keras.Model`` taking ``[input_image, target_image]``.
    """
    weight_init = tf.random_normal_initializer(0., 0.02)

    source = tf.keras.layers.Input(shape=[img_width, img_height, 1], name='input_image')
    candidate = tf.keras.layers.Input(
        shape=[img_width, img_height, output_channels], name='target_image')

    features = tf.keras.layers.concatenate([source, candidate])

    # Three stride-2 blocks with 128, 256 and 256 filters; for 256x256
    # inputs the spatial size goes 128 -> 64 -> 32.
    for n_filters, use_batchnorm in ((128, False), (256, True), (256, True)):
        features = _downsample(n_filters, 4, use_batchnorm)(features)

    # Pad by one pixel per side, then a 4x4 stride-1 conv without padding
    # (32 -> 34 -> 31 for 256x256 inputs).
    features = tf.keras.layers.ZeroPadding2D()(features)
    features = tf.keras.layers.Conv2D(
        512, 4, strides=1,
        kernel_initializer=weight_init,
        use_bias=False)(features)
    features = tf.keras.layers.BatchNormalization()(features)
    features = tf.keras.layers.LeakyReLU()(features)

    # Same pad + conv pattern for the output map (31 -> 33 -> 30).
    features = tf.keras.layers.ZeroPadding2D()(features)
    patch_scores = tf.keras.layers.Conv2D(
        1, 4, strides=1, kernel_initializer=weight_init)(features)

    return tf.keras.Model(inputs=[source, candidate], outputs=patch_scores)

有人知道为什么会这样吗?

标签: tensorflow, keras, keras-layer

解决方案


推荐阅读