Difference between Keras' BatchNormalization and PyTorch's BatchNorm2d?

Question

I have a tiny sample CNN implemented in both Keras and PyTorch. When I print the summary of both networks, the total number of trainable parameters is the same, but the total number of parameters and the number of parameters for Batch Normalization don't match.

Here is the CNN implementation in Keras:

from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Flatten, Dense

IMG_SIZE = 64

inputs = Input(shape=(64, 64, 1))  # Channel last: (NHWC)

model = Conv2D(filters=32, kernel_size=(3, 3), padding='SAME', activation='relu', input_shape=(IMG_SIZE, IMG_SIZE, 1))(inputs)
model = BatchNormalization(momentum=0.15, axis=-1)(model)
model = Flatten()(model)

dense = Dense(100, activation='relu')(model)
head_root = Dense(10, activation='softmax')(dense)

And the summary printed for the above model is:

Model: "model_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_9 (InputLayer)         (None, 64, 64, 1)         0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 64, 64, 32)        320       
_________________________________________________________________
batch_normalization_2 (Batch (None, 64, 64, 32)        128       
_________________________________________________________________
flatten_3 (Flatten)          (None, 131072)            0         
_________________________________________________________________
dense_11 (Dense)             (None, 100)               13107300  
_________________________________________________________________
dense_12 (Dense)             (None, 10)                1010      
=================================================================
Total params: 13,108,758
Trainable params: 13,108,694
Non-trainable params: 64
_________________________________________________________________

Here's the implementation of the same model architecture in PyTorch:

import torch.nn as nn

# Image format: channel first (NCHW) in PyTorch
class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(num_features=32),
        )
        self.flatten = nn.Flatten()
        # 64 x 64 x 32 feature map, flattened
        self.fc1 = nn.Linear(in_features=131072, out_features=100)
        self.fc2 = nn.Linear(in_features=100, out_features=10)

    def forward(self, x):
        output = self.layer1(x)
        output = self.flatten(output)
        output = self.fc1(output)
        output = self.fc2(output)
        return output

And the following is the summary output of the above model:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 64, 64]             320
              ReLU-2           [-1, 32, 64, 64]               0
       BatchNorm2d-3           [-1, 32, 64, 64]              64
           Flatten-4               [-1, 131072]               0
            Linear-5                  [-1, 100]      13,107,300
            Linear-6                   [-1, 10]           1,010
================================================================
Total params: 13,108,694
Trainable params: 13,108,694
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 4.00
Params size (MB): 50.01
Estimated Total Size (MB): 54.02
----------------------------------------------------------------

As you can see in the above results, Batch Normalization in Keras has more parameters than in PyTorch (2x, to be exact). So what's the difference between the above CNN architectures? If they are equivalent, what am I missing here?

Answer

Keras treats as parameters (weights) many things that will be "saved/loaded" in the layer.

While both implementations naturally have the accumulated "mean" and "variance" of the batches, these values are not trainable with backpropagation.

Nevertheless, these values are updated every batch, and Keras treats them as non-trainable weights, while PyTorch simply hides them. The term "non-trainable" here means "not trainable by backpropagation", but doesn't mean the values are frozen.
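
As a quick illustration of "updated, but not trainable", here is a minimal sketch (assuming TensorFlow 2.x; the shapes mirror the 32-channel layer in the question) showing that the moving mean changes after a forward pass in training mode even though it never receives gradients:

import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization(momentum=0.15, axis=-1)
x = np.random.randn(8, 64, 64, 32).astype('float32') + 5.0  # batch mean far from zero

bn(x, training=False)                    # inference mode: statistics are not updated
before = bn.moving_mean.numpy().copy()

bn(x, training=True)                     # training mode: statistics updated in place
after = bn.moving_mean.numpy()

print(np.allclose(before, after))        # False -> updated every (training) batch
print(bn.moving_mean.trainable)          # False -> not trainable by backpropagation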

In total there are 4 groups of "weights" for a BatchNormalization layer. Considering the selected axis (default = -1, size = 32 for your layer), they are as follows (see the inspection sketch after this list):

  • scale (32) - trainable
  • offset (32) - trainable
  • accumulated means (32) - non-trainable, but updated every batch
  • accumulated std (32) - non-trainable, but updated every batch
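
You can confirm this split directly by inspecting the layers. The sketch below (assuming TensorFlow 2.x and PyTorch are installed) shows that Keras reports all 4 * 32 = 128 values as layer weights, while PyTorch keeps gamma/beta as parameters and the running statistics as buffers, so torchsummary only counts 2 * 32 = 64:

import tensorflow as tf
import torch.nn as nn

# Keras: all four groups are layer weights; only gamma and beta are trainable.
bn_keras = tf.keras.layers.BatchNormalization(axis=-1)
bn_keras.build((None, 64, 64, 32))
print(len(bn_keras.weights))             # 4 -> gamma, beta, moving_mean, moving_variance
print(len(bn_keras.trainable_weights))   # 2 -> gamma, beta
print(sum(int(tf.size(w)) for w in bn_keras.weights))  # 128 = 4 * 32

# PyTorch: the running statistics are registered as buffers, not parameters.
bn_torch = nn.BatchNorm2d(num_features=32)
print([n for n, _ in bn_torch.named_parameters()])  # ['weight', 'bias'] -> 2 * 32 = 64
print([n for n, _ in bn_torch.named_buffers()])     # ['running_mean', 'running_var', 'num_batches_tracked']

That 64-value gap is exactly Keras' reported "Non-trainable params: 64", and it accounts for the difference between the two totals (13,108,758 - 13,108,694 = 64).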

The advantage of handling it this way in Keras is that when you save the layer, the mean and variance values are saved automatically, the same way all the other weights in the layer are. And when you load the layer, these weights are loaded back together.
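
A minimal sketch of that behaviour (assuming TensorFlow 2.x; the file name is just an example):

import tensorflow as tf

inp = tf.keras.Input(shape=(64, 64, 1))
out = tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu')(inp)
out = tf.keras.layers.BatchNormalization(momentum=0.15, axis=-1)(out)
model = tf.keras.Model(inp, out)

# get_weights() on the BatchNormalization layer returns all four groups:
# gamma, beta, moving_mean, moving_variance -- each of shape (32,).
print([w.shape for w in model.layers[2].get_weights()])

model.save_weights('bn_demo.weights.h5')   # moving statistics are saved with the rest
model.load_weights('bn_demo.weights.h5')   # ...and restored together on load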
