Disable moving mean and std in Keras batch norm


Problem description

Is there a way to make Keras calculate the mean and std separately on each batch in every BatchNormalization layer, instead of training them (similar to how PyTorch does it)?

To elaborate: the Keras BatchNormalization layer holds four sets of weights [gamma, beta, mean, std] and updates them on every batch.

PyTorch's BatchNorm2d, by contrast, holds only two sets of weights [gamma, beta], and the mean and std are computed for every batch: https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d

I found a similar question, How to set weights of the batch normalization layer?, but it does not provide the functionality I described (similar to PyTorch).
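For reference, the four weight sets can be inspected directly on a built layer; a minimal sketch (the feature size 8 is arbitrary):

from keras.layers import BatchNormalization

bn = BatchNormalization()
bn.build((None, 8))  # build the layer for 8 features
# Four arrays of shape (8,): gamma, beta, moving_mean, moving_variance
print([w.shape for w in bn.get_weights()])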

Solution

Keras "does" calculate the mean and std separately on each batch during training.

These values are not trained the way the other two are; they are just "accumulated". (The accumulated values are only used for prediction/inference.)

That's why half of a BatchNormalization layer's parameters appear as "non-trainable params" in the model summary. (So be careful when assuming that PyTorch doesn't have them. I don't know the PyTorch layer in depth, so ignore this aside if you do, but since these values are not "trainable" by backpropagation, PyTorch may well keep them too and simply not treat them as "weights". Keras counts them as "weights" because they are saved and loaded together with the model, yet they are certainly not trainable in the usual sense.)
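This split is easy to verify; a small sketch (the 8-feature input is arbitrary):

from keras.layers import BatchNormalization, Input
from keras.models import Model

inp = Input(shape=(8,))
bn = BatchNormalization()
model = Model(inp, bn(inp))

print(len(bn.trainable_weights))      # 2 -> gamma, beta (updated by backpropagation)
print(len(bn.non_trainable_weights))  # 2 -> moving_mean, moving_variance (accumulated)
model.summary()                       # 16 trainable and 16 non-trainable parameters for this layer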

Why is it important to accumulate these statistics? Because when you use the model for prediction (not training), if the input batch is not large, or if its mean and std differ a lot from what the model was trained on, you may get wrong results. And in real situations it is not uncommon to want to predict a single item rather than a big batch.
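A minimal numpy sketch of the single-item case (the moving statistics below are made-up numbers) shows why the accumulated values matter at inference time:

import numpy as np

eps = 1e-3
x = np.array([[2.0, -1.0, 5.0]])            # one sample, i.e. a "batch" of size 1

# Normalizing with the batch's own statistics (no moving stats available):
mean, var = x.mean(axis=0), x.var(axis=0)   # mean == x itself, var == 0
print((x - mean) / np.sqrt(var + eps))      # all zeros: the input information is lost

# Normalizing with accumulated training statistics keeps the signal:
moving_mean = np.array([1.0, 0.0, 3.0])
moving_var = np.array([4.0, 1.0, 9.0])
print((x - moving_mean) / np.sqrt(moving_var + eps))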

But if you still want to remove them, you can simply copy the layer's source code and strip everything related to the moving mean and std:

Source code for BatchNormalization

# Imports assumed for standalone use with Keras 2.x (adjust the module paths if your version differs):
from keras import backend as K
from keras import initializers, regularizers, constraints
from keras.engine.base_layer import Layer, InputSpec
from keras.legacy import interfaces


class BatchNormalization2(Layer):
    """
    Copy of keras.layers.BatchNormalization with everything related to the
    moving statistics (moving_mean / moving_variance) commented out, so the
    layer always normalizes with the statistics of the current batch.
    """

    @interfaces.legacy_batchnorm_support
    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 #moving_mean_initializer='zeros',
                 #moving_variance_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(BatchNormalization2, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        #self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        #self.moving_variance_initializer = (initializers.get(moving_variance_initializer))
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')
        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        #self.moving_mean = self.add_weight(
        #    shape=shape,
        #    name='moving_mean',
        #    initializer=self.moving_mean_initializer,
        #    trainable=False)
        #self.moving_variance = self.add_weight(
        #    shape=shape,
        #    name='moving_variance',
        #    initializer=self.moving_variance_initializer,
        #    trainable=False)

        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        #def normalize_inference():
        #    if needs_broadcasting:
        #        # In this case we must explicitly broadcast all parameters.
        #        broadcast_moving_mean = K.reshape(self.moving_mean,
        #                                          broadcast_shape)
        #        broadcast_moving_variance = K.reshape(self.moving_variance,
        #                                              broadcast_shape)
        #        if self.center:
        #            broadcast_beta = K.reshape(self.beta, broadcast_shape)
        #        else:
        #            broadcast_beta = None
        #        if self.scale:
        #            broadcast_gamma = K.reshape(self.gamma,
        #                                        broadcast_shape)
        #        else:
        #            broadcast_gamma = None
        #        return K.batch_normalization(
        #            inputs,
        #            broadcast_moving_mean,
        #            broadcast_moving_variance,
        #            broadcast_beta,
        #            broadcast_gamma,
        #            axis=self.axis,
        #            epsilon=self.epsilon)
        #    else:
        #        return K.batch_normalization(
        #            inputs,
        #            self.moving_mean,
        #            self.moving_variance,
        #            self.beta,
        #            self.gamma,
        #            axis=self.axis,
        #            epsilon=self.epsilon)

        ## If the learning phase is *static* and set to inference:
        #if training in {0, False}:
        #    return normalize_inference()

        # --- for SO answer, ignore the following original comment
        # If the learning is either dynamic, or set to training:
        normed_training, mean, variance = K.normalize_batch_in_training(
            inputs, self.gamma, self.beta, reduction_axes,
            epsilon=self.epsilon)

        if K.backend() != 'cntk':
            sample_size = K.prod([K.shape(inputs)[axis]
                                  for axis in reduction_axes])
            sample_size = K.cast(sample_size, dtype=K.dtype(inputs))
            if K.backend() == 'tensorflow' and sample_size.dtype != 'float32':
                sample_size = K.cast(sample_size, dtype='float32')

            # sample variance - unbiased estimator of population variance
            variance *= sample_size / (sample_size - (1.0 + self.epsilon))

        #self.add_update([K.moving_average_update(self.moving_mean,
        #                                         mean,
        #                                         self.momentum),
        #                 K.moving_average_update(self.moving_variance,
        #                                         variance,
        #                                         self.momentum)],
        #                inputs)

        ## Pick the normalized form corresponding to the training phase.
        #return K.in_train_phase(normed_training,
        #                        normalize_inference,
        #                        training=training)
        return normed_training

    def get_config(self):
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            #'moving_mean_initializer':
            #    initializers.serialize(self.moving_mean_initializer),
            #'moving_variance_initializer':
            #    initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(BatchNormalization2, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape
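
Once defined, the modified layer can be used like the stock one; a minimal usage sketch (the architecture and sizes here are just placeholders):

from keras.models import Sequential
from keras.layers import Dense

model = Sequential([
    Dense(32, activation='relu', input_shape=(16,)),
    BatchNormalization2(),   # always normalizes with the current batch statistics
    Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')
model.summary()              # the BatchNormalization2 layer now shows only trainable gamma/beta weights

Keep in mind that, without the moving statistics, predictions are also normalized with the statistics of whatever batch you pass to predict, so the result for a given sample depends on the rest of the batch. If you save such a model, load it with custom_objects={'BatchNormalization2': BatchNormalization2}.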
 
