首页 > 解决方案 > 如何获取对 ReadVariableOp 正在读取的变量对象的引用?

问题描述

我有以下形式的代码:

project_ops = []
for op in tf.get_default_graph().get_operations():
  if op.type == 'Conv2D':
    activations, kernel = op.inputs
    batch_size, height, width, num_channels = activations.shape.as_list()
    kernel_size_height, kernel_size_width, input_channels, output_channels = kernel.shape.as_list()
    print(activations.shape.as_list(), kernel.shape.as_list())
    project_ops.append(tf.assign(kernel, Orthoganalize(kernel, [height, width])))

这不起作用,因为kernel它不是变量而是 ReadVariableOp。我希望从中获取变量,但它似乎没有对 Python 中可访问的变量的引用?

标签: pythontensorflow

解决方案


这可能不是一种简洁的方式,但您可以使用以下功能:

def get_var_from(readOp): # make sure to pass a tf.Operation with
                          # readOp.type == 'ReadVariableOp'
    handleOp = readOp.inputs[0].op
    for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        if var.op is handleOp:
            return var

这需要一个 tf.Variable (*) 唯一的 VarHandleOp 实例,并读取所有变量以找到使用句柄的变量。(注意,ReadVariableOp 不是变量唯一的。可用于重新读取变量。)

我看不到从句柄到 tf.Variable -instance 的链接,Python API 中可能没有,也许它是一个单向链接。

(*) 据我了解,您使用的是 ResourceVariables,因此此描述可能仅指 ResourceVariables。


推荐阅读