tensorflow - 在 TensorFlow 中,如何查看批量标准化参数?
问题描述
我tf.layers.batch_normalization
在我的网络中使用了一个层。您可能知道,批量归一化对该层中的每个单元 u_i 使用可训练参数 gamma 和 beta,为各种输入 x 选择其自己的标准差和 u_i(x) 的均值。通常,gamma 初始化为 1,beta 初始化为 0。
我有兴趣查看在各个单元中学习的 gamma 和 beta 值,以收集有关它们在网络训练后趋向于结束的位置的统计数据。如何在每个训练实例中查看它们的当前值?
解决方案
您可以获取批处理规范化层范围内的所有变量并打印它们。例子:
import tensorflow as tf
tf.reset_default_graph()
x = tf.constant(3.0, shape=(3,))
x = tf.layers.batch_normalization(x)
print(x.name) # batch_normalization/batchnorm/add_1:0
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
scope='batch_normalization')
print(variables)
#[<tf.Variable 'batch_normalization/gamma:0' shape=(3,) dtype=float32_ref>,
# <tf.Variable 'batch_normalization/beta:0' shape=(3,) dtype=float32_ref>,
# <tf.Variable 'batch_normalization/moving_mean:0' shape=(3,) dtype=float32_ref>,
# <tf.Variable 'batch_normalization/moving_variance:0' shape=(3,) dtype=float32_ref>]
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
gamma = sess.run(variables[0])
print(gamma) # [1. 1. 1.]
推荐阅读
- ruby - 使用 Dir.chdir 处理带有空格的路径名
- python - 在用户模型之外散列密码字段 - Django
- python - 如何在while循环python中从每个循环中收集信息
- javascript - 在Javascript中递归创建选择菜单
- python - 刷新网络应用程序会导致删除的文件重新出现(烧瓶)
- python - 数据框列的单独值计数
- javascript - 试图阻止 Yup 表单验证接受 -0 作为正整数
- pytest - 导入 pytest 后出现错误
- python - 如何在python中随机分配治疗组?
- pygame - Spyder 未检测到 Pygame 模块