python - 在tensorflow2.4.1中生成两次签名
问题描述
我正在使用 tensorflow2.4.1 并提高训练性能。我想在自己的类中将train_step改成AutoGraph,功能如下图:
def train_step(self, input_image, target, run_step):
print( "Current running step: ", run_step )
print( "-------------------train_step tracing~--------------------" )
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
gen_output = self.generator(input_image, training=True)
disc_real_output = self.discriminator(
[input_image, target], training=True)
disc_generated_output = self.discriminator(
[input_image, gen_output], training=True)
gen_total_loss, gen_gan_loss, gen_l1_loss = self.generator_loss(
disc_generated_output, gen_output, target)
disc_loss = self.discriminator_loss(
disc_real_output, disc_generated_output, gen_output, target)
# disc_loss is None in AutoGraph
#tf.print(tf.shape(disc_real_output),tf.shape(disc_generated_output),tf.shape(gen_output),tf.shape(target) )
#tf.print( gen_total_loss, gen_gan_loss, gen_l1_loss )
generator_gradients = gen_tape.gradient(gen_total_loss,
self.generator.trainable_variables)
discriminator_gradients = disc_tape.gradient(disc_loss,
self.discriminator.trainable_variables)
self.generator_optimizer.apply_gradients(zip(generator_gradients,
self.generator.trainable_variables))
self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
self.discriminator.trainable_variables))
self.train_gen_total_loss.update_state(gen_total_loss)
self.train_gen_gan_loss.update_state(gen_gan_loss)
self.train_gen_l1_loss.update_state(gen_l1_loss)
self.train_disc_loss.update_state(disc_loss)
return gen_total_loss, disc_loss
我发现两次打印在
print( "Current running step: ", run_step )
print( "-------------------train_step tracing~--------------------" )
但在我的 test_step 中,只打印一次。这是我的 test_step 的功能。
def test_step(self, input_image, target, run_step):
print( "Current running step: ", run_step )
print( "-------------------test_step tracing~--------------------" )
gen_output = self.generator(input_image, training=False)
disc_real_output = self.discriminator(
[input_image, target], training=False)
disc_generated_output = self.discriminator(
[input_image, gen_output], training=False)
gen_total_loss, gen_gan_loss, gen_l1_loss = self.generator_loss(
disc_generated_output, gen_output, target)
disc_loss = self.discriminator_loss(
disc_real_output, disc_generated_output, gen_output, target)
#tf.print(tf.shape(disc_real_output),tf.shape(disc_generated_output),tf.shape(gen_output),tf.shape(target) )
#tf.print( gen_total_loss, gen_gan_loss, gen_l1_loss )
self.test_gen_total_loss.update_state(gen_total_loss)
self.test_gen_gan_loss.update_state(gen_gan_loss)
self.test_gen_l1_loss.update_state(gen_l1_loss)
self.test_disc_loss.update_state(disc_loss)
return gen_total_loss, disc_loss
这是main函数中的运行代码。
train_step_graph = tf.function(self.train_step) \
.get_concrete_function(input_image, target, self.run_step)
test_step_graph = tf.function(self.test_step) \
.get_concrete_function(input_image, target, self.run_step)
我收到了这样的文字
我想知道为什么为 train_step 打印两次但为 test_step 打印一次。希望你的帮助。
解决方案
推荐阅读
- python - 从 Sqlalchemy 中的 ORM 获取查询语句
- python - ModuleNotFoundError:没有名为“graphframes”的模块
- jenkins - Groovy - 将脚本的输出存储到变量中
- python - 抓取 Airbnb 数据 - Beautifulsoup 输出到 csv 文件
- python - 使用 python regex 获取 raw_text 的两个换行符(\n)之间的所有文本
- c++ - 指定隐式转换之间的优先级
- java - Smallest path from corner to corner of a 2D array
- reactjs - ReactJS 中的自动绑定函数
- django-models - 如何创建本地 id 作为实例模型
- postgresql - 如何按搜索字符串postgresql的第一个单词对查询结果进行排序?