tensorflow - 在估计器模型函数中使用 tf.cond() 在 TPU 上训练 WGAN 会导致 global_step 加倍
问题描述
我正在尝试在 TPU 上训练 GAN,因此我一直在使用 TPUEstimator 类和随附的模型函数来尝试实现 WGAN 训练循环。我正在尝试tf.cond
将 TPUEstimatorSpec 的两个训练操作合并为:
opt = tf.cond(
tf.equal(tf.mod(tf.train.get_or_create_global_step(),
CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1),
lambda: gen_opt,
lambda: critic_opt
)
gen_opt
并且critic_opt
是我正在使用的优化器的最小化功能,也设置为更新全局步骤。CRITIC_UPDATES_PER_GEN_UPDATE
是一个 Python 常量,它是 WGAN 训练的一部分。我尝试使用 找到 GAN 模型tf.cond
,但所有模型都使用tf.group
,我不能使用它,因为您需要比生成器优化更多次批评者。但是,每次运行 100 个批次,全局步长根据检查点编号增加 200。我的模型是否仍在正确训练,或者tf.cond
不应该以这种方式用于训练 GAN?
解决方案
tf.cond
不应该以这种方式用于训练 GAN。
您得到 200,因为每个训练步骤都会评估true_fn
和的副作用(如分配操作) 。false_fn
副作用之一是tf.assign_add
两个优化器定义的全局步骤操作。
因此,发生的事情就像
- 执行
global_step++ (gen_opt)
和global_step++ (critic_op)
- 病情评估
- 执行
true_fn
身体或false_fn
身体(取决于条件)。
如果你想使用 来训练 GAN tf.cond
,你必须从true_fn
/的外部移除所有的辅助操作(比如赋值,因此优化步骤的定义),false_fn
并在其中声明所有内容。
作为参考,您可以看到有关以下行为的答案tf.cond
:https ://stackoverflow.com/a/37064128/2891324
推荐阅读
- c# - 未经确认的电子邮件登录:尽管“options.SignIn.RequireConfirmedEmail = false;”,但 PasswordSignInAsync 返回“NotAllowed”结果
- html - 使 div 不影响其他 div 的大小
- javascript - 如何对设置超时的功能做出反应
- javascript - 在网页中嵌入 HTML 文档,而不是来自 URL
- python - 为什么keras说我只有一个标签而不是三个
- c - 尝试编译 C 文件时出错:mkfifo:无法创建 fifo 'stderr':不支持操作
- python - brew 安装最新的 python3 但 python3 没有更新?
- npm - GatsbyJS - 无法解析“babel-runtime/helpers/possibleConstructorReturn”
- python - 如何为 aws lambda 处理函数(Python)编写单元测试
- html - 使图像的底部有点暗