首页 > 解决方案 > tf.while_loop() 如何在 tensorflow 中执行它的主体部分

问题描述

import tensorflow as tf

n = tf.constant(3)
a = tf.constant(0)

def cond(a, n):
    return a < n

def body(a, n):
    print("box")
    return a+1, n

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    a, n = tf.while_loop(cond, body, [a, n])
    res = sess.run([a, n])
    print(res)



我希望结果是

box
box
box
[10, 10]

然而,结果是

box
[10, 10]

似乎该body功能只执行了一次。
但如果是这种情况,那么结果应该是[1, 10]而不是[10, 10].

我想知道为什么它看起来像这样以及 while_loop 如何执行它的主体部分。

标签: tensorflow

解决方案


创建图表时打印一次。图表创建后,在执行过程中,没有任何操作可打印。如果要在执行阶段打印,则必须明确指定要打印的操作。例如:

import tensorflow as tf

n = tf.constant(3)
a = tf.constant(0)

def cond(a, n):
    return a < n

def body(a, n):
    return tf.Print(a + 1, [], message='box'), n # <-- op for increment and print

a, n = tf.while_loop(cond, body, [a, n])

with tf.Session() as sess:

    res = sess.run([a, n])
    print(res)
# box
# box
# box
# [3, 3]

来自tf.Print()文档的注释:

注意:此操作打印到标准错误。它目前与 jupyter notebook 不兼容(打印到笔记本服务器的输出,而不是到笔记本中)。


推荐阅读