首页 > 解决方案 > 如何在 TensorFlow 中获取张量的值(无需进行另一个会话)

问题描述

我正在寻找一种获得张量值的方法。在大多数情况下,可以通过调用“sess.run(target_op)”来解决问题。但是,我想知道另一种方式。我正在编辑从 GitHub 下载的代码,因此那里已经有一个正在运行的会话代码。在不触及会话运行部分的情况下,有没有办法获得一些特定的张量值?在我的例子中,代码是为了获得图像识别的准确性而构建的。在会话运行并进行准确性评估时,我还想在同一会话中获得“预测”张量值,而无需创建另一个会话。例如,像 tf.Print 这样的操作通过终端窗口显示张量值,而无需直接运行会话(在第一张图中,我们只需执行 sess.run(e) 即可从 c 中打印出张量) tf.Print 示例

a = tf.constant(5)
b = tf.constant(3)
c = tf.add(a,b)

#print tensor c (which is 8)
d = tf.Print(c,[c])
f = tf.constant(2)
e = tf.multiply(f,d)
sess = tf.Session()

#print operation can be executed without running the session directly
g = sess.run(e)`

像 tf.Print 一样,是否有任何操作可以在不直接运行会话的情况下获取张量值?(如第二个图) 我正在寻找的操作示例

更具体地说,我想要的是获取张量的值(使用实际的数字和数组,而不仅仅是“张量”数据结构)并将其传递给全局变量以在会话关闭后自由访问该值。该会话仅执行位于图表末尾的运算符,而我想要的值位于图表中间的张量。由于限制我不能创建比原始代码更多的会话,有没有办法获得特定的张量值?(我不能使用 .eval() 或 .run() 因为两者都需要访问“会话”。我正在编辑的代码使用 slim.evaluate_once 函数运行代码,并且由于 session() 绑定到该函数,我无法接近 session())

标签: pythonpython-2.7tensorflowtensor

解决方案


只要您输入适当的feed_dict. 例如说你想要一个被调用的张量biasAdd:0,你所谓的末端张量被称为prediction

然后你可以得到这个张量并评估它:

tensor = graph.get_tensor_by_name("biasAdd:0")
tensor_value, prediction_value = ses.run([tensor, prediction],... )

在 tensorflow 中,您必须使用 run 或 eval 从图中获取数值


推荐阅读