tensorflow - TensorFlow 的 tf.control_dependencies 无法按预期工作
问题描述
TensorFlow 版本为 1.14.0
我的示例代码如下:
import sys
import tensorflow as tf
from tensorflow.python import debug as tf_debug
sess = tf.Session()
m = tf.get_variable('m', initializer=tf.constant(1.0))
n = tf.get_variable('n', initializer=tf.constant(2.0))
# z = m*m + n*n
z = (m*n) ** 2
g1 = tf.gradients(z, [m, n])
r = 0.3 * g1[0]
attack_op = m.assign(m + r)
# attack_op = m.assign_add(r)
with tf.control_dependencies([attack_op]):
z_a = (m*n) ** 2
g2 = tf.gradients(z_a, [m, n])
with tf.control_dependencies(g2):
restore_op = m.assign(m - r)
with tf.control_dependencies([restore_op]):
g3 = [tf.add_n(x) for x in zip(g1, g2)]
g4 = [x / 2 for x in g3]
sess.run(tf.global_variables_initializer())
print('m: %g' % sess.run(m))
print('n: %g' % sess.run(n))
print(sess.run([g1, g2, g3, g4]))
以下代码的结果是:
m: 1
n: 2
[[8.0, 4.0], [27.2, 46.24], [35.2, 50.24], [8.0, 4.0]]
但是g4的手动计算结果是:[17.6, 25.12]
这个奇怪结果的原因是什么?谢谢
解决方案
推荐阅读
- azure - 如何在开发 Azure Function 版本 2 时在 local.settings.json 中配置“ConsumerGroup”
- javascript - 检测 webapp 上是否启用了 UDP 或 TCP
- java - 在 Activity 被销毁时删除 Auth Firebase 用户
- qt-creator - QtCreator 中的 Google cpplint 集成
- python - 在 Pandas 中将包含年份和周数的字符串转换为日期时间
- dart - 在颤动中使用底部导航栏在页面之间传递数据
- assembly - 具有无限个参数 MIPS 的函数
- javascript - 如何使用 react-testing-library 中的 `wait` 来测试不处理被拒绝承诺的组件?
- arrays - 在本机反应中未在渲染中设置状态数组值
- android - 在自定义视图类中播放 GIF