首页 > 解决方案 > 在 TensorFlow 中,如何查看批量标准化参数?

问题描述

tf.layers.batch_normalization在我的网络中使用了一个层。您可能知道,批量归一化对该层中的每个单元 u_i 使用可训练参数 gamma 和 beta,为各种输入 x 选择其自己的标准差和 u_i(x) 的均值。通常,gamma 初始化为 1,beta 初始化为 0。

我有兴趣查看在各个单元中学习的 gamma 和 beta 值,以收集有关它们在网络训练后趋向于结束的位置的统计数据。如何在每个训练实例中查看它们的当前值?

标签: tensorflowmachine-learningneural-networkpython-3.6

解决方案


您可以获取批处理规范化层范围内的所有变量并打印它们。例子:

import tensorflow as tf

tf.reset_default_graph()
x = tf.constant(3.0, shape=(3,))
x = tf.layers.batch_normalization(x)

print(x.name) # batch_normalization/batchnorm/add_1:0

variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                              scope='batch_normalization')
print(variables)

#[<tf.Variable 'batch_normalization/gamma:0' shape=(3,) dtype=float32_ref>,
# <tf.Variable 'batch_normalization/beta:0' shape=(3,) dtype=float32_ref>,
# <tf.Variable 'batch_normalization/moving_mean:0' shape=(3,) dtype=float32_ref>,
#  <tf.Variable 'batch_normalization/moving_variance:0' shape=(3,) dtype=float32_ref>]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    gamma = sess.run(variables[0])
    print(gamma) # [1. 1. 1.]

推荐阅读