张量流中的批量标准化 [英] Batch Normalization in tensorflow

查看:48
本文介绍了张量流中的批量标准化的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我注意到 tensorflow 的 api 中已经有批量标准化函数.但是我不明白的一件事是如何更改训练和测试之间的程序?

I noticed there are batch normalization functions already in the api for tensorflow. One thing I don't understand though, is how to to change the procedure between training and test?

批量归一化在测试和训练期间的作用不同.特别是在训练期间使用固定的均值和方差.

Batch normalization acts differently during test than during training. Specifically one uses a fixed mean and variance during training.

某处有一些很好的示例代码吗?我看到了一些,但是对于范围变量,它变得令人困惑

Is there some good example code somewhere? I saw some, but with scope variables it got confusing

推荐答案

您说得对,tf.nn.batch_normalization 仅提供了实现批量标准化的基本功能.您必须添加额外的逻辑以在训练期间跟踪移动均值和方差,并在推理期间使用经过训练的均值和方差.您可以查看此示例以获得非常通用的实现,但是这里有一个不使用 gamma 的快速版本:

You are right, the tf.nn.batch_normalization provides just the basic functionality for implementing batch normalization. You have to add the extra logic to keep track of moving means and variances during training, and use the trained means and variances during inference. You can look at this example for a very general implementation, but a quick version that doesn't use gamma is here :

  beta = tf.Variable(tf.zeros(shape), name='beta')
  moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean',
                                 trainable=False)
  moving_variance = tf.Variable(tf.ones(shape),
                                     name='moving_variance',
                                     trainable=False)
  control_inputs = []
  if is_training:
    mean, variance = tf.nn.moments(image, [0, 1, 2])
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, self.decay)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, self.decay)
    control_inputs = [update_moving_mean, update_moving_variance]
  else:
    mean = moving_mean
    variance = moving_variance
  with tf.control_dependencies(control_inputs):
    return tf.nn.batch_normalization(
        image, mean=mean, variance=variance, offset=beta,
        scale=None, variance_epsilon=0.001)

这篇关于张量流中的批量标准化的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆