tensorflow - 如何在 Tensorflow 的全局命名空间中存储和检索张量?
问题描述
我正在尝试阅读一个大型 Tensorflow 项目。对于计算图的节点分散在项目周围的项目,我想知道是否有一种方法可以存储计算图的 Tensor 节点并将该节点添加到 sess.run 中的 fetch 列表中?
例如,如果我想在 https://github.com/allenai/document-qa/blob/master/docqa/nn/span_prediction.py的第 615 行将概率添加到全局命名空间,是否有类似 tf. add_node(probs, "probs"),后来我可以得到 tf.get_node("probs"),只是为了方便在项目中传递节点。
一个更普遍的问题是,构建 tensorflow 代码并提高试验不同模型的效率会有什么更好的想法。
解决方案
当然可以。要稍后检索它,您必须为其命名,以便您可以按名称检索它。以probs
您的代码为例。它是使用tf.nn.softmax()
函数创建的,其 API 如下所示。
tf.nn.softmax(
logits,
axis=None,
name=None,
dim=None
)
看到参数了name
吗?您可以将此参数添加到第 615 行,如下所示:
probs = tf.nn.softmax(all_logits, name='my_tensor')
稍后当你需要它时,你可以调用tf.Graph.get_tensor_by_name(name)
来检索这个张量。
graph = tf.get_default_graph()
retrieved_probs = graph.get_tensor_by_name('my_tensor:0')
'my_tensor'
是softmax操作的名称,“:0”应该添加到它的末尾,这意味着您正在检索张量而不是操作。调用时Graph.get_operation_by_name()
,不应添加 ':0'。
您必须确保张量存在(它可能在此行之前执行的代码中创建,或者可能从元图文件中恢复)。如果它是在变量范围内创建的,您还必须在name
参数前面添加范围名称和“/”。例如,'my_scope/my_tensor:0'
。
推荐阅读
- cron - Openssl 命令不会仅针对特定证书在 crontab 中运行
- python - 如何获取记录第一次更改的索引数据框
- sql-server - 在 SQL 中删除单个记录需要花费大量时间
- reactjs - 使用 TypeScript 动态格式化反应表中的标题
- teamcity - TeamCity AssemblyInfo Patcher:如何使用三位数的版本?
- c# - 使用 .Net 我如何以编程方式导航到网页,通过代码与之交互,然后从新生成的页面中获取特定值
- delphi - FindFirst 找不到所有扩展名为 ~1~ 的文件和其他文件
- r - 分配反应式两次忽略第一个
- stata - 使用 Stata 中的时间序列运算符将缺失的观测值替换为以前的值时类型不匹配
- flutter - Flutter setState()从有状态的孩子调用时不刷新小部件