首页 > 解决方案 > 如何在 Tensorflow 的全局命名空间中存储和检索张量?

问题描述

我正在尝试阅读一个大型 Tensorflow 项目。对于计算图的节点分散在项目周围的项目,我想知道是否有一种方法可以存储计算图的 Tensor 节点并将该节点添加到 sess.run 中的 fetch 列表中?

例如,如果我想在 https://github.com/allenai/document-qa/blob/master/docqa/nn/span_prediction.py的第 615 行将概率添加到全局命名空间,是否有类似 tf. add_node(probs, "probs"),后来我可以得到 tf.get_node("probs"),只是为了方便在项目中传递节点。

一个更普遍的问题是,构建 tensorflow 代码并提高试验不同模型的效率会有什么更好的想法。

标签: tensorflow

解决方案


当然可以。要稍后检索它,您必须为其命名,以便您可以按名称检索它。以probs您的代码为例。它是使用tf.nn.softmax()函数创建的,其 API 如下所示。

tf.nn.softmax(
    logits,
    axis=None,
    name=None,
    dim=None
)

看到参数了name吗?您可以将此参数添加到第 615 行,如下所示:

 probs = tf.nn.softmax(all_logits, name='my_tensor')

稍后当你需要它时,你可以调用tf.Graph.get_tensor_by_name(name)来检索这个张量。

graph = tf.get_default_graph()
retrieved_probs = graph.get_tensor_by_name('my_tensor:0')

'my_tensor'softmax操作的名称,“:0”应该添加到它的末尾,这意味着您正在检索张量而不是操作。调用时Graph.get_operation_by_name(),不应添加 ':0'。

您必须确保张量存在(它可能在此行之前执行的代码中创建,或者可能从元图文件中恢复)。如果它是在变量范围内创建的,您还必须在name参数前面添加范围名称和“/”。例如,'my_scope/my_tensor:0'


推荐阅读