Signature generated twice in TensorFlow 2.4.1

Problem description

I am using TensorFlow 2.4.1 and trying to improve training performance. I want to turn the train_step method of my own class into an AutoGraph function. Here it is:

def train_step(self, input_image, target, run_step):
    # Python print() runs only while tf.function traces this method, not on every step
    print( "Current running step: ", run_step )
    print( "-------------------train_step tracing~--------------------" )

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = self.generator(input_image, training=True)

        disc_real_output = self.discriminator(
            [input_image, target], training=True)
        disc_generated_output = self.discriminator(
            [input_image, gen_output], training=True)

        gen_total_loss, gen_gan_loss, gen_l1_loss = self.generator_loss(
            disc_generated_output, gen_output, target)

        disc_loss = self.discriminator_loss(
            disc_real_output, disc_generated_output, gen_output, target)

        # disc_loss is None in AutoGraph

    #tf.print(tf.shape(disc_real_output),tf.shape(disc_generated_output),tf.shape(gen_output),tf.shape(target) )
    #tf.print( gen_total_loss, gen_gan_loss, gen_l1_loss )
    generator_gradients = gen_tape.gradient(gen_total_loss,
                                            self.generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                                 self.discriminator.trainable_variables)

    self.generator_optimizer.apply_gradients(zip(generator_gradients,
                                                 self.generator.trainable_variables))
    self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                     self.discriminator.trainable_variables))

    self.train_gen_total_loss.update_state(gen_total_loss)
    self.train_gen_gan_loss.update_state(gen_gan_loss)
    self.train_gen_l1_loss.update_state(gen_l1_loss)
    self.train_disc_loss.update_state(disc_loss)

    return gen_total_loss, disc_loss

I found that these two lines are printed twice:

print( "Current running step: ", run_step )
print( "-------------------train_step tracing~--------------------" )
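For context: a Python `print()` inside a function wrapped by `tf.function` runs only while TensorFlow traces the function into a graph, whereas `tf.print` becomes a graph op and runs on every call. The minimal sketch below (the names `trace_log` and `add_one` are hypothetical) counts traces with a Python list: getting the concrete function traces once, and calling it three times afterwards does not run the Python code again.

```python
import tensorflow as tf

trace_log = []  # appended to by Python code, i.e. only while tracing

@tf.function
def add_one(x):
    trace_log.append("traced")         # Python side effect: trace time only
    tf.print("executing with x =", x)  # graph op: runs on every call
    return x + 1

concrete = add_one.get_concrete_function(tf.TensorSpec([], tf.float32))
for i in range(3):
    concrete(tf.constant(float(i)))    # executes the graph; no new trace

print(len(trace_log))  # 1
```

So counting the `print` lines counts traces, not training steps.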

But in my test_step, they are printed only once. Here is my test_step function:

def test_step(self, input_image, target, run_step):
    print( "Current running step: ", run_step )
    print( "-------------------test_step tracing~--------------------" )

    gen_output = self.generator(input_image, training=False)

    disc_real_output = self.discriminator(
        [input_image, target], training=False)
    disc_generated_output = self.discriminator(
        [input_image, gen_output], training=False)

    gen_total_loss, gen_gan_loss, gen_l1_loss = self.generator_loss(
        disc_generated_output, gen_output, target)
    disc_loss = self.discriminator_loss(
        disc_real_output, disc_generated_output, gen_output, target)

    #tf.print(tf.shape(disc_real_output),tf.shape(disc_generated_output),tf.shape(gen_output),tf.shape(target) )
    #tf.print( gen_total_loss, gen_gan_loss, gen_l1_loss )
    self.test_gen_total_loss.update_state(gen_total_loss)
    self.test_gen_gan_loss.update_state(gen_gan_loss)
    self.test_gen_l1_loss.update_state(gen_l1_loss)
    self.test_disc_loss.update_state(disc_loss)

    return gen_total_loss, disc_loss

This is the code in my main function that builds the graphs:

train_step_graph = tf.function(self.train_step).get_concrete_function(
    input_image, target, self.run_step)
test_step_graph = tf.function(self.test_step).get_concrete_function(
    input_image, target, self.run_step)
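Note that `self.run_step` is passed next to the image tensors. If it is a plain Python number (its type is not shown in the question, so this is an assumption), `tf.function` treats it as a constant: its value is frozen into the trace, the printed step is whatever it was at trace time, and calling the `tf.function` with a different number triggers a new trace. A small sketch with a hypothetical `step` function:

```python
import tensorflow as tf

traced_with = []  # Python-side record of the argument seen by each trace

@tf.function
def step(run_step):
    traced_with.append(run_step)  # runs only while tracing
    return tf.constant(0)

for i in range(3):
    step(i)               # Python int: every new value causes a new trace
for i in range(3):
    step(tf.constant(i))  # int32 scalar tensor: traced once, then reused

print(len(traced_with))   # 4
```

Passing the step as a tensor (or keeping it in a `tf.Variable`) avoids per-value retracing; inside the graph, `tf.print` shows its current value.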

This is the output I got:

[Screenshot of the console output]

I'd like to know why these lines are printed twice for train_step but only once for test_step. Any help is appreciated.
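A likely explanation: `tf.function` only supports creating variables on its first trace. When that first trace does create variables, `get_concrete_function` traces the Python function a second time with variable creation disabled, so trace-time `print` calls run twice. In `train_step` the probable source is the two `apply_gradients` calls, because Keras optimizers in TF 2.x create their iteration counter and slot variables lazily on the first `apply_gradients`. By the time `test_step` is traced the models already exist and it calls no optimizer, so it creates no variables and is traced once. The sketch below reproduces the pattern with a hypothetical `Steps` class, where a lazily created `tf.Variable` stands in for the optimizer state:

```python
import tensorflow as tf

class Steps:
    """Hypothetical stand-in for the asker's class; names are illustrative."""

    def __init__(self):
        self.counter = None  # created lazily, like optimizer slot variables
        self.train_traces = 0
        self.test_traces = 0

    def train_step(self, x):
        self.train_traces += 1  # Python code: runs once per trace
        if self.counter is None:
            self.counter = tf.Variable(0.0)  # created during the first trace only
        self.counter.assign_add(x)
        return self.counter.read_value()

    def test_step(self, x):
        self.test_traces += 1  # creates no variables
        return x * 2.0

s = Steps()
spec = tf.TensorSpec([], tf.float32)
train_cf = tf.function(s.train_step).get_concrete_function(spec)
test_cf = tf.function(s.test_step).get_concrete_function(spec)
print(s.train_traces, s.test_traces)  # 2 1
```

The second trace is a one-time cost at graph construction; the returned concrete function is then reused for every step.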

Tags: python, tensorflow2.0
