python - 为什么张量流中的for循环这么慢
问题描述
所以我知道这与 tensorflow 何时构建图表有关,但它做得不好......“有效”。这是我正在运行的虚拟代码:
@tf.function
def parTest(x_in):
res = 0
for i in range(5000):
res += x_in + i
return res
在不使用 tensorflow 的情况下运行该函数需要 0.002 秒,但使用 tensorflow 运行该函数需要 10 到 20 秒。这对我来说毫无意义,这是怎么回事?另外,我该如何解决?这里的 res 的实际值显然可以以更有效的方式计算,但我遇到的真正问题是我有一个 for 循环,其中每次迭代都有很多迭代可以彼此独立运行,但 tensorflow 拒绝这样做并让它们一个一个地运行很慢,就像这个虚拟示例一样。那么我如何告诉 tensorflow 不要这样做呢?
解决方案
循环在 TensorFlow 中从来都不是很有效。然而,这个函数对 TensorFlow 尤其不利,因为它会尝试静态地“展开”整个循环。也就是说,它不会将您的函数“翻译”成 a tf.while_loop
,而是会在每次迭代中创建 5000 个操作副本。这是一个非常大的图表,除此之外,它总是会按顺序运行。实际上,我在 TensorFlow 2.2.0 中收到了一个警告,它指向这个信息页面:“警告:检测到大型展开循环”。
如该链接中所述,问题在于 TensorFlow 无法(至少目前)检测任意迭代器上的循环,即使它们是简单的range
,所以它只是在 Python 中运行循环并创建相应的操作。您可以通过编写tf.while_loop
自己来避免这种情况,或者感谢AutoGraph,只需将 your 替换range
为tf.range
:
import tensorflow as tf
@tf.function
def parTest(x_in):
res = 0
for i in tf.range(5000):
res += x_in + i
return res
尽管如此,编写自己的tf.while_loop
(在绝对必要时,因为矢量化操作总是更快)可以让您更明确地控制parallel_iterations
参数等细节。
推荐阅读
- terminal - Visual Studio 代码:以退出代码终止:3221225477
- r - 如何将一列数据附加到R中的多列
- android - 如何在意图android中传递2个文件
- php - 如何在 Twig 中的元素上添加多个 ID
- r - R中的ifelse命令没有填充数据框中的列?
- reactjs - 尝试通过道具向子级提升功能时超出最大更新深度
- sql-server - SQL Server 如何使用“。”映射我的服务器。作为服务器名称?
- javascript - 与多个对象数组相交
- r - 如果不是所有列都存在于 R 中的所有数据框中,则按列合并数据框
- python - 如何在while循环中从列表中删除元组