首页 > 解决方案 > graph.get_tensor_by_name 和 tf.global_variable 之间的区别

问题描述

我可以通过 graph.get_tensor_by_name 获得一个张量,但是我在 tf.global_variable 中找不到它。就我而言,我定义了一些 tf.Tensor 如下:

output_y = Dense(units=y.shape[1],activation='softmax',kernel_regularizer=regularizers.l2(),bias_regularizer=regularizers.l2(),activity_regularizer=regularizers.l2(),name='output_y_'+str(index))(pretrain_output)
y_tf = tf.placeholder(tf.float32, shape=(None, y.shape[1]),name='y_tf_'+str(index))
loss_tensor = tf.nn.softmax_cross_entropy_with_logits(logits=output_y, labels=y_tf, name='loss_tensor_' + str(index))

我可以按如下方式导出张量形状和名称:

>>output_y
<tf.Tensor 'train_variable/output_y_0/Softmax:0' shape=(?, 4) dtype=float32>
>>y_tf
<tf.Tensor 'train_variable/y_tf_0:0' shape=(?, 4) dtype=float32>
>>loss_tensor
<tf.Tensor 'train_variable/loss_tensor_0/Reshape_2:0' shape=(?,) dtype=float32>

另外,我可以使用 tf.get_default_graph.get_tensor_by_name 来检索张量:

>>tf.get_default_graph().get_tensor_by_name('train_variable/output_y_0/Softmax:0')
<tf.Tensor 'train_variable/output_y_0/Softmax:0' shape=(?, 4) dtype=float32>
>>tf.get_default_graph().get_tensor_by_name('train_variable/y_tf_0:0')
<tf.Tensor 'train_variable/y_tf_0:0' shape=(?, 4) dtype=float32>
>>tf.get_default_graph().get_tensor_by_name('train_variable/loss_tensor_0/Reshape_2:0')
<tf.Tensor 'train_variable/loss_tensor_0/Reshape_2:0' shape=(?,) dtype=float32>

但是,在 tf.global_variables() 中找不到这些变量名。似乎 tf.global_variables() 只包含内核/偏差等参数变量。现在我必须记住张量名称才能检索对象输出(在我的情况下为 output_y)。有人可以告诉我如何检索张量,例如在包含所有张量的列表中搜索它吗?

标签: tensorflowmachine-learning

解决方案


来自节点的读取操作的张量和作为变量的张量之间存在差异。

一个变量由一个值和几个操作组成:

import tensorflow as tf
a = tf.get_variable('a', tf.float32)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

sess.run(a)  # gives 42.
sess.run(tf.get_default_graph().get_tensor_by_name('a/read:0'))  # gives 42. as well
print(a.op.outputs)  # <tf.Tensor 'a:0' shape=() dtype=float32_ref>]

它的行为类似:

>>> type(a)
<class 'tensorflow.python.ops.variables.Variable'>
>>> type(tf.get_default_graph().get_tensor_by_name('a/read:0'))
<class 'tensorflow.python.framework.ops.Tensor'>

但它们是不同的。

最简单的方法是返回output_y以防您再次需要它。否则请按照: https ://stackoverflow.com/a/36893840/7443104


推荐阅读