tensorflow - tf.while_loop() 如何在 tensorflow 中执行它的主体部分
问题描述
import tensorflow as tf
n = tf.constant(3)
a = tf.constant(0)
def cond(a, n):
return a < n
def body(a, n):
print("box")
return a+1, n
with tf.Session() as sess:
tf.global_variables_initializer().run()
a, n = tf.while_loop(cond, body, [a, n])
res = sess.run([a, n])
print(res)
我希望结果是
box
box
box
[10, 10]
然而,结果是
box
[10, 10]
似乎该body
功能只执行了一次。
但如果是这种情况,那么结果应该是[1, 10]
而不是[10, 10]
.
我想知道为什么它看起来像这样以及 while_loop 如何执行它的主体部分。
解决方案
创建图表时打印一次。图表创建后,在执行过程中,没有任何操作可打印。如果要在执行阶段打印,则必须明确指定要打印的操作。例如:
import tensorflow as tf
n = tf.constant(3)
a = tf.constant(0)
def cond(a, n):
return a < n
def body(a, n):
return tf.Print(a + 1, [], message='box'), n # <-- op for increment and print
a, n = tf.while_loop(cond, body, [a, n])
with tf.Session() as sess:
res = sess.run([a, n])
print(res)
# box
# box
# box
# [3, 3]
来自tf.Print()文档的注释:
注意:此操作打印到标准错误。它目前与 jupyter notebook 不兼容(打印到笔记本服务器的输出,而不是到笔记本中)。
推荐阅读
- c# - XNA 框架枚举中为键调用的“@”符号是什么?
- javascript - 在 AWS Lambda 中获取正文请求
- python - 制作蛇游戏(按照教程):我收到错误,对象不可下标
- iis - 如何区分 IIS 的状态码和 Web API 应用程序的状态码
- android - 无法解析符号“默认”
- angular - 在同一级别共享 Angular 子到子数据
- jquery - 如何将 jquery 从 Spring Boot 获取的图像显示到 DOM 上
- java - 可执行jar找不到打包资源
- elasticsearch - 通过 Logstash 将 JSON 数据记录到 Elasticsearch 是否有附加价值?
- docker - 命令“/bin/sh -c conda update conda”返回非零代码:127