首页 > 解决方案 > 为什么张量流中的for循环这么慢

问题描述

所以我知道这与 tensorflow 何时构建图表有关,但它做得不好......“有效”。这是我正在运行的虚拟代码:

@tf.function
def parTest(x_in):
    res = 0
    for i in range(5000):
        res += x_in + i
    return res

在不使用 tensorflow 的情况下运行该函数需要 0.002 秒,但使用 tensorflow 运行该函数需要 10 到 20 秒。这对我来说毫无意义,这是怎么回事?另外,我该如何解决?这里的 res 的实际值显然可以以更有效的方式计算,但我遇到的真正问题是我有一个 for 循环,其中每次迭代都有很多迭代可以彼此独立运行,但 tensorflow 拒绝这样做并让它们一个一个地运行很慢,就像这个虚拟示例一样。那么我如何告诉 tensorflow 不要这样做呢?

标签: pythontensorflowfor-loop

解决方案


循环在 TensorFlow 中从来都不是很有效。然而,这个函数对 TensorFlow 尤其不利,因为它会尝试静态地“展开”整个循环。也就是说,它不会将您的函数“翻译”成 a tf.while_loop,而是会在每次迭代中创建 5000 个操作副本。这是一个非常大的图表,除此之外,它总是会按顺序运行。实际上,我在 TensorFlow 2.2.0 中收到了一个警告,它指向这个信息页面:“警告:检测到大型展开循环”

如该链接中所述,问题在于 TensorFlow 无法(至少目前)检测任意迭代器上的循环,即使它们是简单的range,所以它只是在 Python 中运行循环并创建相应的操作。您可以通过编写tf.while_loop自己来避免这种情况,或者感谢AutoGraph,只需将 your 替换rangetf.range

import tensorflow as tf
@tf.function
def parTest(x_in):
    res = 0
    for i in tf.range(5000):
        res += x_in + i
    return res

尽管如此,编写自己的tf.while_loop(在绝对必要时,因为矢量化操作总是更快)可以让您更明确地控制parallel_iterations参数等细节。


推荐阅读