首页 > 解决方案 > 澄清 tf.Session() 作用域 TensorFlow

问题描述

我有一个 python 程序,我在其中定义了一个网络,并且像往常一样,我在我拥有的函数中对其进行训练

with tf.Session() as sess:
...
for epoch in xrange(num_epochs):
...
for n in xrange(num_batches):
       _, c = sess.run([optimizer, loss], feed_dict={....

在损失函数中,我必须做很多工作才能得到损失,特别是,我必须在张量中取最大值并用它来做一些事情。这里有一个例子

values = tf.constant([0, 1, 2, 0, 2], dtype=tf.float32)
max_values = tf.reduce_max(values) # Tensor...
...

如果我使用调试,max_values它表示它是张量而不是值,所以如果我以这种方式更改我的代码,则将在前一段代码中创建的会话传递给函数

values = tf.constant([0, 1, 2, 0, 2], dtype=tf.float32)
max_values = sess.run(tf.reduce_max(values)) # 2.0
...

有用。但是这个损失函数已经在会话的范围内,所以我的问题是为什么结果是张量而不是数字?有没有办法在不将会话传递给损失函数的情况下获取值?

标签: pythontensorflow

解决方案


根据文档

TensorFlow 使用 tf.Session 类来表示客户端程序(通常是 Python 程序,尽管在其他语言中也有类似的接口)和 C++ 运行时之间的连接。

这意味着当你这样做时,values = tf.constant([0, 1, 2, 0, 2], dtype=tf.float32)你只是将一个节点插入到你的 tensorflow 图表中!由于 Python 是用于低级 C++ 运行时的高级 API,因此您实际上需要一个 Session 来在这个低级运行时评估您的 Python 代码。

这就是为什么每次您需要计算或评估 Tensorflow 变量/方法/常量/等时,您都需要在 Session 中运行它tf.Session().run(yournode)

我希望它有帮助


推荐阅读