首页 > 解决方案 > 如何在 py_func 中访问 Tensorflow Graph?

问题描述

我正在尝试get_default_graphpy_func. 我的代码片段如下所示:

def test():
    tf.get_default_graph().get_tensor_by_name("Variable_1:0")


t = tf.py_func(test, [], tf.float32)

我收到一个错误:

KeyError:“名称‘Variable_1:0’指的是一个不存在的张量。操作‘Variable_1’在图中不存在。”

但我确实命名了一个名为Variable_1. 我的假设是在py_func.

如何使上述代码段起作用?

标签: pythontensorflowmachine-learning

解决方案


根据tf.py_func()的定义和给定示例,您需要传递输入并指定输出的数据类型(如果有多个)。

您的代码将需要像这样重写:

def test():
    # Do whatever computation and return a variable of type float

inp = tf.get_default_graph().get_tensor_by_name("Variable_1:0")
t = tf.py_func(test, [inp], tf.float32)

推荐阅读