首页 > 解决方案 > 张量流测试中的渴望和图形执行

问题描述

我有一些适用于图形和会话的测试。我还想用渴望模式编写一些小测试来轻松测试一些功能。例如:

def test_normal_execution():
    matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
    dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
    iterator = dataset.make_one_shot_iterator()
    first_elem = iterator.get_next()
    with tf.Session() as sess:
        result = sess.run(first_elem)
        assert (result == [1, 2, 3, 4]).all()
    sess.close()

在另一个文件中:

def test_eager_execution():
    matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
    tf.enable_eager_execution()
    dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
    iterator = dataset.__iter__()
    first_elem = iterator.next()
    assert (first_elem.numpy() == [1, 2, 3, 4]).all() 

有没有办法解决这个问题?ValueError: tf.enable_eager_execution must be called at program startup.当我尝试运行急切执行的测试时,我得到了。我pytest用来运行我的测试。

编辑

在接受响应的帮助下,我创建了一个装饰器,它可以很好地与 Eager 模式和 pytest 的固定装置配合使用:

def run_eagerly(func):
    @functools.wraps(func)
    def eager_fun(*args, **kwargs):
        with tf.Session() as sess:
            sess.run(tfe.py_func(func, inp=list(kwargs.values()), Tout=[]))

    return eager_fun

标签: unit-testingtensorflow

解决方案


需要注意的是tf.contrib命名空间中的任何内容都可能在版本之间发生变化,您可以使用@tf.contrib.eager.run_test_in_graph_and_eager_modes. 其他一些项目,比如 TensorFlow Probability似乎使用了这个

对于非测试,需要研究的一些事情是:

  • tf.contrib.eager.defun:当您启用了急切执行但想要将一些计算“编译”到图形中以从内存和/或性能优化中受益时很有用。
  • tf.contrib.eager.py_func:当没有启用急切执行但想在图中以 Python 函数的形式执行一些计算时很有用。

人们可能会质疑不允许tf.enable_eager_execution()撤消呼叫的原因。这个想法是库作者不应该调用它,只有最终用户应该在main(). 这减少了以不兼容方式编写库的可能性(例如,一个库中的函数禁用急切执行并返回符号张量,而另一个库中的函数启用急切执行并期望具体的值张量。这会使库的混合成为问题)。

希望有帮助


推荐阅读