How to remove the mean in a Keras layer without considering the variance, like BatchNormalization?

Problem description

I want to do part of what the BatchNormalization layer in Keras does: remove the mean and keep a moving average of it. Unfortunately, the BatchNormalization layer in Keras always takes the variance into account as well, and I don't want that.

I was considering the Average and Subtract layers, but they don't store anything for later use once training ends. The idea is that my layer learns and removes the mean, so that at prediction time it subtracts a constant value.
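
(For context, a short sketch of why the built-in layer cannot be configured this way: BatchNormalization's center and scale arguments only toggle the learned beta and gamma parameters, while the division by the variance always happens. The values below are illustrative only.)

import numpy as np
import tensorflow as tf

x = np.array([[1.0], [2.0], [3.0]], dtype=np.float32)

# Even with center=False and scale=False, BatchNormalization still
# divides by sqrt(variance + epsilon); it cannot be reduced to pure
# mean subtraction.
bn = tf.keras.layers.BatchNormalization(center=False, scale=False)
print(bn(x, training=True))   # (x - mean) / sqrt(var + eps) -> ~[-1.22, 0, 1.22]

# What is wanted instead: subtract the (learned) mean only.
print(x - x.mean(axis=0))     # [[-1.], [0.], [1.]]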

Tags: tensorflow, keras, normalization, centering, batch-normalization

Solution


I created a Centering layer to do this, copying code from BatchNormalization. It uses momentum to update the current moving mean. It seems to work, and I can save and load models that use it.

from tensorflow.keras import backend
from tensorflow.keras import initializers
from tensorflow.keras import layers
from tensorflow import math
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables


class Centering(layers.Layer):
    """Layer that centers the data learning a mean."""

    def __init__(self, momentum=0.01, **kwargs):
        """Constructor of LatentProjection."""
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super().__init__(**kwargs)
        self.input_spec = layers.InputSpec(min_ndim=2)
        self.momentum = momentum
        self.moving_mean = None

    def build(self, input_shape):
        """Create internal variables."""
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]
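        # Declared the same way BatchNormalization declares its moving
        # statistics: ON_READ synchronization with MEAN aggregation keeps
        # the variable well-defined under tf.distribute strategies.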
        self.moving_mean = self.add_weight(
            name='moving_mean',
            shape=(input_dim,),
            initializer=initializers.Zeros,
            synchronization=variables.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=variables.VariableAggregation.MEAN,
            experimental_autocast=False)
        self.input_spec = layers.InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def _get_training_value(self, training=None):
        """Copied from normalization.py."""
        if training is None:
            training = backend.learning_phase()
        if isinstance(training, int):
            training = bool(training)
        if not self.trainable:
            # When the layer is not trainable, it overrides the value
            # passed from model.
            training = False
        return training

    def _support_zero_size_input(self):
        """Copied from normalization.py."""
        return distribution_strategy_context.has_strategy() and getattr(
            distribution_strategy_context.get_strategy().extended,
            'experimental_enable_get_next_as_optional', False)

    def _assign_moving_average(self, variable, value, momentum, inputs_size):
        """Copied from normalization.py."""
        with backend.name_scope('AssignMovingAvg') as scope:
            with ops.colocate_with(variable):
                decay = ops.convert_to_tensor_v2(1.0 - momentum, name='decay')
                if decay.dtype != variable.dtype.base_dtype:
                    decay = math_ops.cast(decay, variable.dtype.base_dtype)
                update_delta = (variable -
                                math_ops.cast(value, variable.dtype)) * decay
                if inputs_size is not None:
                    update_delta = array_ops.where(
                        inputs_size > 0, update_delta,
                        backend.zeros_like(update_delta))
                return state_ops.assign_sub(variable, update_delta, name=scope)

    def call(self, inputs, training=None, **kwargs):
        """Called for each mini batch when applied to input layer."""
        training = self._get_training_value(training)
        training_value = tf_utils.constant_value(training)
        # `constant_value` returns None when `training` is a symbolic
        # tensor; `== False` (rather than `is False`) deliberately lets
        # None fall through to the batch-mean branch.
        if training_value == False:
            mean = self.moving_mean
        else:
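            # Batch mean over axis 0 (per feature for 2-D inputs). For
            # inputs with more than two dims, every non-feature axis would
            # need to be reduced to match moving_mean's shape.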
            mean = math.reduce_mean(inputs, axis=0)
            # Following code copied from normalization.py to update moving mean
            if self._support_zero_size_input():
                # Keras assumes that batch dimension is the first dimension for
                # Batch Normalization.
                input_batch_size = array_ops.shape(inputs)[0]
            else:
                input_batch_size = None

            def mean_update():
                """Perform update of moving mean average using copied code."""
                return self._assign_moving_average(
                    self.moving_mean, mean, self.momentum, input_batch_size)
            self.add_update(mean_update)
        # Center inputs
        return inputs - mean

    def get_config(self):
        """Internal config of this layer."""
        config = {
            'momentum': self.momentum,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
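
A minimal usage sketch (model shape, file name, and data are made up for illustration): the layer drops into a model like any other, and a saved model reloads as long as Centering is passed via custom_objects.

import numpy as np
from tensorflow import keras

# Toy model: center the inputs, then a dense head (shapes are arbitrary).
inputs = keras.Input(shape=(4,))
x = Centering(momentum=0.01)(inputs)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')

data = np.random.rand(256, 4).astype('float32')
model.fit(data, np.random.rand(256, 1).astype('float32'), epochs=1, verbose=0)

model.save('centering_model.h5')
restored = keras.models.load_model(
    'centering_model.h5', custom_objects={'Centering': Centering})

# At inference time (training=False) the layer subtracts the stored
# moving mean, i.e. a constant learned during training.
print(restored.predict(data[:2]))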
