tensorflow - 张量流中的批量标准化问题
问题描述
我无法理解 Tensorflow 中批量标准化的实现。为了说明,我创建了一个包含一个输入节点、一个隐藏节点和一个输出节点的简单网络,并以 1 个批次运行,批次大小为 2。我的输入 x 包含一个具有 2 个值的标量(即批次大小2),一组为 0,另一组为 1。
我运行了一个 epoch,并写出隐藏层的输出(批标准化前后)以及批范数移动均值、方差、伽马和贝塔。
这是我的代码:
import tensorflow as tf
import numpy as np
N_HIDDEN_1 = 1
N_INPUT= 1
N_OUTPUT = 1
###########################################################
# DEFINE THE Network
# Define placeholders for data that will be fed in during execution
x = tf.placeholder(tf.float32, (None, N_INPUT))
y = tf.placeholder(tf.float32, (None, N_OUTPUT))
lx = tf.placeholder(tf.float32, [])
training = tf.placeholder_with_default(False, shape=(), name='training')
# Hidden layers with relu activation
with tf.variable_scope('hidden1'):
hidden_1 = tf.layers.dense(x, N_HIDDEN_1, activation=None, use_bias=False)
bn_1 = tf.layers.batch_normalization(hidden_1, training=training, momentum=0.5)
bn_1x = tf.nn.relu(bn_1)
# Output layer
with tf.variable_scope('output'):
predx = tf.layers.dense(bn_1x, N_OUTPUT, activation=None, use_bias=False)
pred = tf.layers.batch_normalization(predx, training=training, momentum=0.5)
###########################################################
# Define the cost function that is optimized when
# training the network and the optimizer
cost = tf.reduce_mean(tf.square(pred-y))
optimizer = tf.train.AdamOptimizer(learning_rate=lx).minimize(cost)
bout1 = tf.global_variables('hidden1/batch_normalization/moving_mean:0')
bout2 = tf.global_variables('hidden1/batch_normalization/moving_variance:0')
bout3 = tf.global_variables('hidden1/batch_normalization/gamma:0')
bout4 = tf.global_variables('hidden1/batch_normalization/beta:0')
###########################################################
# Train network
init = tf.global_variables_initializer()
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.Session() as sess:
sess.run(init)
# Create dummy data
batchx = np.zeros((2,1))
batchy = np.zeros((2,1))
batchx[0,0]=0.0
batchx[1,0]=1.0
batchy[0,0]=3.0
batchy[1,0]=4.0
_,_ = sess.run([optimizer, extra_update_ops], feed_dict={training: True, x:batchx, y:batchy, lx: 0.001})
print('weight of hidden layer')
W1 = np.array(sess.run(tf.global_variables('hidden1/dense/kernel:0')))
W1x = np.sum(W1, axis=1)
print(W1x)
print()
print('output from hidden layer, batch norm layer, and relu layer')
hid1,b1,b1x = sess.run([hidden_1, bn_1, bn_1x], feed_dict={training: False, x:batchx})
print('hidden_1', hid1)
print('bn_1', b1)
print('bn_1x', b1x)
print()
print('batchnorm parameters')
print('moving mean', sess.run(bout1))
print('moving variance', sess.run(bout2))
print('gamma', sess.run(bout3))
print('beta', sess.run(bout4))
这是我运行代码时得到的输出:
weight of hidden layer [[1.404974]]
output from hidden layer, batch norm layer, and relu layer
hidden_1 [[0. ]
[1.404974]]
bn_1 [[-0.40697935]
[ 1.215785 ]]
bn_1x [[0. ]
[1.215785]]
batchnorm parameters
moving mean [array([0.3514931], dtype=float32)]
moving variance [array([0.74709475], dtype=float32)]
gamma [array([0.999], dtype=float32)]
beta [array([-0.001], dtype=float32)]
我对生成的 batchnorm 参数感到困惑。在这种特殊情况下,隐藏层在应用批范数之前的输出是标量 0 和 1.404974。但批范数参数移动均值是 0.3514931。这是我使用动量 = 0.5 的情况。我不清楚为什么 1 次迭代后的移动平均值在这种情况下不完全是 0 和 1.404974 的平均值。我的印象是动量参数只会从第二批开始。
任何帮助将非常感激。
解决方案
因为您运行了优化器,所以很难知道里面到底发生了什么:您正在打印的 hidden_1 值不是用于更新批量标准统计信息的值;它们是更新后的值。
无论如何,我真的没有看到这个问题:
Moving mean original value = 0.0
batch mean value = (1.404974 - 0.0) / 2.0 = ~0.7
Moving mean value = momentum * Moving mean original value + (1 - momentum) * batch mean value
= 0.0 * 0.5 + (1 - 0.5) * 0.7
= 0.35
Moving variance original value = 1.0
batch variance value = ~0.5
Moving variance value = momentum * Moving variance original value + (1 - momentum) * batch variance value
= 1.0 * 0.5 + (1.0 - 0.5) * 0.5
= 0.75
推荐阅读
- css - 将 svg 添加到内容属性
- flutter - 我必须在我的 Stack 小部件中添加溢出吗?
- python-3.x - pynput 只能处理来自特定键盘的输入吗?如果是,如何实现?
- c# - 如何通过文本框在datagridview中搜索
- ubuntu - 如何安装 CMake 在 find_package 中包含 catkin/ROS 的项目?
- java - 使用方面修改方法参数
- eclipse - 原因:java.lang.NoClassDefFoundError: Failed to link when running JUnit tests through Arquiliqn for a Maven project
- sql - 由于 SAS 中的交错结果,如何折叠重复行
- node.js - Nodejs 发现 Nodejs 的其他实例
- javascript - 从纬度/经度坐标获取 WOEID