首页 > 解决方案 > 什么可能导致在张量上运行 .eval() 永远不会结束?

问题描述

我目前正在阅读有关如何使用 TensorFlow 训练 CNN 并使用它根据 CIFAR-10 数据集对图像进行分类的教程。运行评估脚本 cifar10_eval.py 时,输出是模型相对于测试集的准确度的精度等级。相反,我想查看测试数据上每个类别的模型分类输出。logits的计算和存储方式是通过:

# Build a graph that computes the logits predictions from the 
# inference model. 
logits = cifar10.inference(images)

运行此行后,我编辑了脚本以通过以下方式显示“logits”变量的类型、形状和元素的类型:

print(type(logits))
print(logits.dtype)
print(logits.shape)

它返回以下输出:

类'tensorflow.python.framework.ops.Tensor'

数据类型:'float32'

(128, 10)

我假设形状是 (128,10),因为有 128 张测试图像,每张图像都被评估为 10 个类别中的每一个的可能性。为了显示这一点,我正在尝试以下代码:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(logits.eval())

这个 .eval() 语句永远不会终止,我想知道我哪里出了问题以及如何解决这个问题以便我可以访问 logits?

标签: pythontensorflow

解决方案


这可能是因为您正在打开一个新会话(并用它重新初始化变量!)。尝试在创建它的同一会话中进行评估logits。但奇怪的是它没有终止,它应该引发错误。另外,这tf.Session()不是tf.Session


推荐阅读