python - 如何获取对 ReadVariableOp 正在读取的变量对象的引用?
问题描述
我有以下形式的代码:
project_ops = []
for op in tf.get_default_graph().get_operations():
if op.type == 'Conv2D':
activations, kernel = op.inputs
batch_size, height, width, num_channels = activations.shape.as_list()
kernel_size_height, kernel_size_width, input_channels, output_channels = kernel.shape.as_list()
print(activations.shape.as_list(), kernel.shape.as_list())
project_ops.append(tf.assign(kernel, Orthoganalize(kernel, [height, width])))
这不起作用,因为kernel
它不是变量而是 ReadVariableOp。我希望从中获取变量,但它似乎没有对 Python 中可访问的变量的引用?
解决方案
这可能不是一种简洁的方式,但您可以使用以下功能:
def get_var_from(readOp): # make sure to pass a tf.Operation with
# readOp.type == 'ReadVariableOp'
handleOp = readOp.inputs[0].op
for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
if var.op is handleOp:
return var
这需要一个 tf.Variable (*) 唯一的 VarHandleOp 实例,并读取所有变量以找到使用句柄的变量。(注意,ReadVariableOp 不是变量唯一的。可用于重新读取变量。)
我看不到从句柄到 tf.Variable -instance 的链接,Python API 中可能没有,也许它是一个单向链接。
(*) 据我了解,您使用的是 ResourceVariables,因此此描述可能仅指 ResourceVariables。
推荐阅读
- html - 如何根据视口中的部分将类添加到正文标签
- c# - Unity 在 Visual Studio 中找不到使用 Nuget 安装的包
- angular - 加载程序 app.component.html 中的 Nativescript 错误没有返回字符串
- javascript - 如何实现水平滚动flatlist react native
- date - 将天数添加到日期时间并将其转换为日期?- Odoo V14
- sql - DART SQL 如何从 Dart 中的 SQL 查询结果中读取多行值(多列)?
- instagram - 有没有收集公共 Instagram 统计数据的 Instagram API?
- ios - 为 iOS 构建时获得棕色精灵
- angular - 管道角度变换内容标签
- php - laravel 验证,存储和更新唯一值,尝试获取非对象的属性