tensorflow - TensorFlow:双射器中的变量不能重用
问题描述
描述问题
我试图在 MaskedAutoregressiveFlow 双射器中重用神经网络中的权重和偏差,方法是将其放在tf.variable_scope
with 中reuse=tf.AUTO_REUSE
。但是发现在实践中没有重用权重和偏差。
复制
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import bijectors as tfb
def get_bijector(name='my_bijector', reuse=None):
"""Returns a MAF bijector."""
with tf.variable_scope(name, reuse=reuse):
shift_and_log_scale_fn = \
tfb.masked_autoregressive_default_template([128])
return tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn)
x = tf.placeholder(shape=[None, 64], dtype='float32', name='x')
bijector_0 = get_bijector(reuse=tf.AUTO_REUSE)
y_0 = bijector_0.forward(x)
bijector_1 = get_bijector(reuse=tf.AUTO_REUSE)
y_1 = bijector_1.forward(x)
# We were expecting that the `y_0` and `y_1` share the same dependent variables,
# since we used `tf.AUTO_REUSE` within the `tf.variable_scope`. However, the following
# will return a `False`.
print(get_dependent_variables(y_0) == get_dependent_variables(y_1))
其中我们使用了获得张量所依赖的所有变量的函数:
import collections
def get_dependent_variables(tensor):
"""Returns all variables that the tensor `tensor` depends on.
Forked from: https://stackoverflow.com/a/42861919/1218716
Args:
tensor: Tensor.
Returns:
List of variables.
"""
# Initialize
starting_op = tensor.op
dependent_vars = []
queue = collections.deque()
queue.append(starting_op)
op_to_var = {var.op: var for var in tf.trainable_variables()}
visited = {starting_op}
while queue:
op = queue.popleft()
try:
dependent_vars.append(op_to_var[op])
except KeyError:
# `op` is not a variable, so search its inputs (if any).
for op_input in op.inputs:
if op_input.op not in visited:
queue.append(op_input.op)
visited.add(op_input.op)
return dependent_vars
解决方案
推荐阅读
- android - 如何计算并知道从当前日期和时间开始的 24 小时
- java - 为 AsyncTask 实现一个常规的 JSON 解析器函数
- angular - 在“元素”上执行“请求全屏”失败时出现错误:API 只能由用户手势启动
- xslt - TransformerFactoryImpl 类加载问题
- ios - 在 swift 4 中验证电话号码时,我从 firebase 收到“无效令牌”
- angular - 从api获取数据时,Angular 6 Chartist Donut Chart第一次没有加载
- r - 如何在闪亮的应用程序中根据标称特征的级别为饼图着色?
- javascript - 如何将此数组结果值存储在数据库中?
- angularjs - 在 ng-repeat AngularJS 中定位特定元素
- javascript - 在 ScrollView onscroll 属性中调用函数