首页 > 解决方案 > control_dependencies 不符合预期

问题描述

我希望 tensorflow 在f(...)

  1. 获取数据[索引]
  2. 缓存值
  3. 返回数据[索引]

tf.control_dependencies没有做我想做的事。

如何修复控制依赖?

结果:

cache_ 0.0
x_ 2.0
AssertionError

测试:

import tensorflow as tf
import numpy as np


def f(a, cache):
    assign_op = tf.assign(cache, a)
    with tf.control_dependencies([assign_op]):
        return a


def main():
    dtype = np.float32
    data = tf.range(5, dtype=dtype)
    cache = tf.Variable(0, dtype=dtype)
    x = f(data[2], cache)
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        x_ = sess.run(x)
        cache_ = sess.run(cache)
    print("cache_", cache_)
    print("x_", x_)
    assert np.allclose(cache_, x_)


main()

标签: pythontensorflow

解决方案


问题在于return aPython 代码。您没有在with块中创建任何 TensorFlow 操作。您可以使用它tf.identity来创建一个操作,以确保首先执行a读取的时间。assign_op这是更新的代码:

def f(a, cache):
    assign_op = tf.assign(cache, a)
    with tf.control_dependencies([assign_op]):
        return tf.identity(a)

推荐阅读