How to implement Batch Norm with SWA in Tensorflow?

Problem Description

I am using Stochastic Weight Averaging (SWA) with Batch Normalization layers in Tensorflow 2.2. For Batch Norm I use tf.keras.layers.BatchNormalization. For SWA I use my own code to average the weights (I wrote my code before tfa.optimizers.SWA appeared). I have read in multiple sources that, if using batch norm and SWA, we must run a forward pass to make certain data (the running mean and std dev of the activations, and/or momentum values?) available to the batch norm layers. What I do not understand, despite a lot of reading, is exactly what needs to be done and how. Specifically:

  1. When must the forward/prediction pass be run? At the end of each mini-batch, end of each epoch, end of all training?
  2. When the forward pass is run, how are the running mean & stdev values made available to the batch norm layers?
  3. Is this process performed magically by the tfa.optimizers.SWA class?

Recommended Answer

When must the forward/prediction pass be run? At the end of each mini-batch, end of each epoch, end of all training?

At the end of training. Think of it like this: SWA is performed by swapping your final weights for a running average. But all the batch norm layers still hold statistics computed from your old weights, so we need to run a forward pass to let them catch up.
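
For concreteness, here is a minimal sketch of that averaging and swap, in the spirit of the asker's own manual approach (all names such as `swa_state` and `update_swa` are hypothetical, and `model` is assumed to be a tf.keras.Model):

```python
# Hypothetical sketch of manual SWA: keep a running average of the
# model weights and swap it in at the end of training.
swa_state = {"weights": None, "n": 0}

def update_swa(model):
    """Fold the current weights into the running SWA average."""
    current = model.get_weights()
    if swa_state["weights"] is None:
        swa_state["weights"] = current
    else:
        # Incremental mean: avg <- avg + (w - avg) / (n + 1)
        swa_state["weights"] = [
            avg + (w - avg) / (swa_state["n"] + 1)
            for avg, w in zip(swa_state["weights"], current)
        ]
    swa_state["n"] += 1

# Called e.g. at the end of each epoch during the SWA phase; at the
# end of training the average replaces the final weights:
# model.set_weights(swa_state["weights"])
```

After that swap, the batch norm statistics no longer match the weights, which is why the forward pass below is needed.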

When the forward pass is run, how are the running mean & stdev values made available to the batch norm layers?

During a normal forward pass (prediction) the running mean and standard deviation will not be updated. So what we actually need to do is train the network without updating the weights. This is what the paper means when it says to run the forward pass in "training mode".
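
A minimal sketch of such a statistics-refreshing pass in TF2 Keras (assuming `model` already carries the SWA weights and `dataset` yields input batches, e.g. a tf.data.Dataset of x tensors):

```python
import tensorflow as tf

# Refresh the BN statistics with forward passes only.
for x in dataset:
    # training=True makes each BatchNormalization layer update its
    # moving_mean / moving_variance; no gradients are computed and no
    # optimizer step is taken, so the trainable weights stay untouched.
    model(x, training=True)
```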

The easiest way to achieve this (that I know of) is to reset the batch normalization layers and train one additional epoch with the learning rate set to 0.
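
That could look something like the following sketch (assuming `model`, `x_train` and `y_train` exist and the loss matches your task; `moving_mean` and `moving_variance` are the standard variables on tf.keras.layers.BatchNormalization):

```python
import tensorflow as tf

# Reset the BN statistics, then "train" one epoch with lr = 0.
for layer in model.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.moving_mean.assign(tf.zeros_like(layer.moving_mean))
        layer.moving_variance.assign(tf.ones_like(layer.moving_variance))

# SGD with learning_rate=0 and no momentum changes no weights, but
# model.fit still runs in training mode, so the BN moving statistics
# are recomputed from the SWA weights.
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.0),
              loss="sparse_categorical_crossentropy")
model.fit(x_train, y_train, epochs=1)
```

One caveat: with the default BN momentum of 0.99 the moving estimates are an exponential average, so a single epoch may only partially converge them; lowering each layer's momentum before this pass lets the estimates track the fresh statistics faster.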

Is this process performed magically by the tfa.optimizers.SWA class?

I don't know. But if you are using Tensorflow Keras, then I have made this Keras SWA callback that does it as in the paper, including the learning rate schedules.
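
For completeness, a minimal sketch of how tfa.optimizers.SWA is typically wired up (assuming `model`, `x_train` and `y_train` exist; whether the class also handles the batch norm pass is, as said above, unknown to me, so the statistics refresh described earlier may still be needed):

```python
import tensorflow as tf
import tensorflow_addons as tfa

# The wrapper maintains a running average of the weights alongside
# the base optimizer.
base_opt = tf.keras.optimizers.SGD(learning_rate=0.01)
swa_opt = tfa.optimizers.SWA(base_opt, start_averaging=100,
                             average_period=10)

model.compile(optimizer=swa_opt, loss="sparse_categorical_crossentropy")
model.fit(x_train, y_train, epochs=10)

# Copy the averaged weights into the model; the batch norm statistics
# then still need to be refreshed as described above.
swa_opt.assign_average_vars(model.variables)
```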
