python - TensorFlow 的 while 循环比传统的 while 循环慢
问题描述
这是一个执行基本添加操作的常规 while 循环 -
import time
def check(a,b):
while(a<b):
a += 1
return [a,b]
a = 1
b = 1500000
start = time.time()
check(a,b)
print("Time = ",time.time() - start)
Time = 0.07060480117797852
这是使用 Tensorflow 的优化代码 -
import tensorflow as tf
def cond(t1, t2):
return tf.less(t1, t2)
def body(t1, t2):
return [tf.add(t1, 1), t2]
t1 = tf.constant(1)
t2 = tf.constant(1500000)
start = time.time()
res = tf.while_loop(cond, body, [t1, t2], parallel_iterations = 10)
print("Time = ",time.time() - start)
Time = 22.1693217754364
为什么优化后的代码比传统代码执行得慢。我有一个 6GB 内存的 GTX GPU。有任何想法吗?
解决方案
您不能指望 tf.while_loop 会比简单的 python 循环更快,例如
for( int i=0; i<1500000; i++)
j=j+1;
在 python、javaScript、c 等中总是会表现得更好。
tensorflow 针对矩阵运算进行了高度优化,而不是针对简单的循环。
我知道你只是在探索不同的方法,这很棒。
一种更张量流的方式是
t1 = tf.constant(1)
t2 = tf.constant(1500000)
start = time.time()
t3 = t1 + t2
print("Time = ",time.time() - start)
在我的机器上,时间 = 0.0010001659393310547
因此,tf 是线性代数运算的框架。如果您尝试将其用作通用语言框架,它不会做得很好。
PS。我想我以前在这里见过你 :-D 新年快乐!
推荐阅读
- javascript - 将时间转换为 ist
- python - 如何正确绘制训练和验证集的损失曲线?
- java - 为什么一段时间后我在 Java 中得到 403 状态码?
- javascript - 为什么我不断从对象数组的“减少”方法中获得“未定义”?
- python - 如何打乱矩阵的行及其相应的标签?
- php - Woocommerce - 在变体选择时禁用图像更改
- r - 在 R 中使用 Keras 和 Tensorflow 的 ValueError
- python-3.x - 我想要的是在 python 语言编码考试中对 stdin 给出的读取或使用输入的代码进行口头解释
- vim - 替换所有行中的模式,但仅在另一个模式之前
- python - 在 Anaconda 上将 Python 更新到 3.9