tensorflow - 如何像 BatchNormalization 那样去除 Keras 层的均值,但不考虑方差?
问题描述
我想做 Keras 中 BatchNormalization 层所做的事情,即移除平均值并存储移动平均值。不幸的是,Keras 中的 BatchNormalization 层总是同时考虑方差,而我并不想使用方差。
我正在考虑使用 Average 和 Subtract 层,但是当训练结束时它们不会存储任何东西以供使用。这个想法是我的层移除并学习平均值,所以在测试预测时,它减去一个常数值。
解决方案
我创建了一个名为 Centering 的层来执行此操作,其代码复制自 BatchNormalization。它使用动量来更新当前的移动平均值。它似乎有效,我可以用它保存和加载模型。
from tensorflow.keras import backend
from tensorflow.keras import initializers
from tensorflow.keras import layers
from tensorflow import math
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
class Centering(layers.Layer):
    """Layer that centers its inputs by subtracting a learned mean.

    While training, the mean of the current batch (taken over every axis
    except the last one) is subtracted from the inputs and folded into the
    non-trainable ``moving_mean`` weight. At inference time the stored
    ``moving_mean`` is subtracted instead. No variance scaling is applied.

    Args:
        momentum: Fraction of the previous moving mean kept at each update:
            ``moving_mean = moving_mean * momentum
            + batch_mean * (1 - momentum)``.
        **kwargs: Forwarded to ``layers.Layer``. A legacy ``input_dim``
            keyword is translated to ``input_shape`` when the latter is
            not given.
    """

    def __init__(self, momentum=0.01, **kwargs):
        """Store the momentum and declare the minimum input rank."""
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super().__init__(**kwargs)
        self.input_spec = layers.InputSpec(min_ndim=2)
        # NOTE(review): with the default of 0.01 the stored mean is almost
        # entirely the latest batch mean (see the update formula in the
        # class docstring). Keras' BatchNormalization defaults to 0.99 --
        # confirm that 0.01 is really the intended default.
        self.momentum = momentum
        self.moving_mean = None

    def build(self, input_shape):
        """Create the ``moving_mean`` weight, one entry per last-axis feature.

        Args:
            input_shape: Shape of the inputs; must have at least 2 dimensions.

        Raises:
            ValueError: If ``input_shape`` has fewer than 2 dimensions.
        """
        if len(input_shape) < 2:
            raise ValueError(
                'Centering expects inputs with at least 2 dimensions, '
                'got input_shape={}'.format(input_shape))
        input_dim = input_shape[-1]

        self.moving_mean = self.add_weight(
            name='moving_mean',
            shape=(input_dim,),
            initializer=initializers.Zeros(),
            synchronization=variables.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=variables.VariableAggregation.MEAN,
            experimental_autocast=False)

        self.input_spec = layers.InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def _get_training_value(self, training=None):
        """Resolve the effective training flag (copied from normalization.py).

        Falls back to the backend learning phase when ``training`` is None,
        converts integer flags to bool, and forces False when the layer is
        not trainable.
        """
        if training is None:
            training = backend.learning_phase()
        if isinstance(training, int):
            training = bool(training)
        if not self.trainable:
            # When the layer is not trainable, it overrides the value
            # passed from model.
            training = False
        return training

    def _support_zero_size_input(self):
        """Tell whether empty batches must be guarded against.

        Copied from normalization.py: true only when a distribution strategy
        is active and its ``experimental_enable_get_next_as_optional`` flag
        is set.
        """
        return distribution_strategy_context.has_strategy() and getattr(
            distribution_strategy_context.get_strategy().extended,
            'experimental_enable_get_next_as_optional', False)

    def _assign_moving_average(self, variable, value, momentum, inputs_size):
        """Move ``variable`` towards ``value`` (copied from normalization.py).

        Computes ``variable -= (variable - value) * (1 - momentum)``.

        Args:
            variable: Variable holding the running average.
            value: New observation to fold into the average.
            momentum: Fraction of the old average that is kept.
            inputs_size: Batch size tensor, or None. When given, the update
                is zeroed out for batches of size 0.

        Returns:
            The assignment op returned by ``state_ops.assign_sub``.
        """
        with backend.name_scope('AssignMovingAvg') as scope:
            with ops.colocate_with(variable):
                decay = ops.convert_to_tensor_v2(1.0 - momentum, name='decay')
                if decay.dtype != variable.dtype.base_dtype:
                    decay = math_ops.cast(decay, variable.dtype.base_dtype)
                update_delta = (
                    variable - math_ops.cast(value, variable.dtype)) * decay
                if inputs_size is not None:
                    # Skip the update entirely for an empty batch.
                    update_delta = array_ops.where(
                        inputs_size > 0, update_delta,
                        backend.zeros_like(update_delta))
                return state_ops.assign_sub(variable, update_delta, name=scope)

    def call(self, inputs, training=None, **kwargs):
        """Center ``inputs`` and, when training, update the moving mean.

        Args:
            inputs: Tensor of rank >= 2 whose last axis is the feature axis.
            training: Training flag; resolved by ``_get_training_value``.
            **kwargs: Unused; accepted for compatibility with Keras.

        Returns:
            ``inputs`` minus the stored moving mean (inference) or minus the
            mean of the current batch (training).
        """
        training = self._get_training_value(training)
        training_value = tf_utils.constant_value(training)

        # `== False` is deliberate: training_value is None when the flag is
        # not statically known, and that case is treated as training.
        if training_value == False:  # noqa: E712  pylint: disable=singleton-comparison
            return inputs - self.moving_mean

        # Reduce over every axis but the last so the result matches the
        # (input_dim,) shape of moving_mean for inputs of any rank >= 2.
        # For rank-2 inputs this is the same as reducing over axis 0.
        reduction_axes = list(range(inputs.shape.rank - 1))
        mean = math.reduce_mean(inputs, axis=reduction_axes)

        # Following code copied from normalization.py to update moving mean
        if self._support_zero_size_input():
            # Keras assumes that batch dimension is the first dimension for
            # Batch Normalization.
            input_batch_size = array_ops.shape(inputs)[0]
        else:
            input_batch_size = None

        def mean_update():
            """Return the op that folds the batch mean into moving_mean."""
            return self._assign_moving_average(
                self.moving_mean, mean, self.momentum, input_batch_size)

        self.add_update(mean_update)

        # Center inputs
        return inputs - mean

    def get_config(self):
        """Return the layer config, including ``momentum``, for serialization."""
        base_config = super().get_config()
        return {**base_config, 'momentum': self.momentum}
推荐阅读
- reactjs - React - 获取多个 api
- c# - 在 C# 中拆分包含英语和希伯来语的字符串
- python - 每隔一秒出现一次未知元素后拆分字符串
- html - Azure 应用服务 - 访问 PWA webmanifest.json 时出现 CORS OAUTH2 错误
- c++ - 如何使用 opencv c++ 和 android studio ndk 将图像读入 Mat 对象?
- javascript - 服务器和客户端之间的双向通信
- json.net - 从 MongoDB.Driver 迁移到 Azure DocumentClient 时从 CosmosDB 反序列化的问题
- swift - 快速在本地保存 MTLTexture 数据
- c# - 全球资源翻译不起作用
- java - XML 到 JAVA (JAXB) 错误 - “SYSTEM”和系统标识符之间需要空格