首页 > 解决方案 > 在估计器模型函数中使用 tf.cond() 在 TPU 上训练 WGAN 会导致 global_step 加倍

问题描述

我正在尝试在 TPU 上训练 GAN,因此我一直在使用 TPUEstimator 类和随附的模型函数来尝试实现 WGAN 训练循​​环。我正在尝试tf.cond将 TPUEstimatorSpec 的两个训练操作合并为:

opt = tf.cond(
    tf.equal(tf.mod(tf.train.get_or_create_global_step(), 
    CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1), 
    lambda: gen_opt, 
    lambda: critic_opt
)

gen_opt并且critic_opt是我正在使用的优化器的最小化功能,也设置为更新全局步骤。CRITIC_UPDATES_PER_GEN_UPDATE是一个 Python 常量,它是 WGAN 训练的一部分。我尝试使用 找到 GAN 模型tf.cond,但所有模型都使用tf.group,我不能使用它,因为您需要比生成器优化更多次批评者。但是,每次运行 100 个批次,全局步长根据检查点编号增加 200。我的模型是否仍在正确训练,或者tf.cond不应该以这种方式用于训练 GAN?

标签: tensorflowgenerative-adversarial-networktpu

解决方案


tf.cond不应该以这种方式用于训练 GAN。

您得到 200,因为每个训练步骤都会评估true_fn和的副作用(如分配操作) 。false_fn副作用之一是tf.assign_add两个优化器定义的全局步骤操作。

因此,发生的事情就像

  • 执行global_step++ (gen_opt)global_step++ (critic_op)
  • 病情评估
  • 执行true_fn身体或false_fn身体(取决于条件)。

如果你想使用 来训练 GAN tf.cond,你必须从true_fn/的外部移除所有的辅助操作(比如赋值,因此优化步骤的定义),false_fn并在其中声明所有内容。

作为参考,您可以看到有关以下行为的答案tf.condhttps ://stackoverflow.com/a/37064128/2891324


推荐阅读