What is the difference between register_parameter and register_buffer in PyTorch?

Question

A module's parameters change during training; that is, they are what a neural network learns. But what is a buffer, and is it also learned during training?

Recommended answer

The PyTorch documentation for the register_buffer() method reads:

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the persistent state.

As you already observed, model parameters are learned and updated by SGD (or another gradient-based optimizer) during training.
However, sometimes there are other quantities that are part of a model's "state" and should be
- saved as part of state_dict.
- moved to cuda() or cpu() with the rest of the model's parameters.
- cast to float/half/double with the rest of the model's parameters.
Registering these quantities as buffers of the model lets PyTorch track them and save them like regular parameters, but keeps them out of the SGD update: buffers are not returned by parameters(), so an optimizer built from parameters() never touches them.
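The distinction above can be sketched with a small module. This is only an illustrative example, not code from the PyTorch documentation: the class name Scaler and the buffer name offset are hypothetical, chosen for this sketch.

```python
import torch
import torch.nn as nn


class Scaler(nn.Module):
    """Hypothetical module: a learnable weight plus a fixed offset."""

    def __init__(self):
        super().__init__()
        # Learnable: returned by parameters(), so an optimizer updates it.
        self.register_parameter("weight", nn.Parameter(torch.ones(3)))
        # Not learnable: tracked as module state, but never handed to the optimizer.
        self.register_buffer("offset", torch.zeros(3))

    def forward(self, x):
        return x * self.weight + self.offset


model = Scaler()
param_names = [name for name, _ in model.named_parameters()]
buffer_names = [name for name, _ in model.named_buffers()]
state_keys = list(model.state_dict().keys())  # contains both names

# Casting the module casts floating-point buffers along with the parameters.
model.double()
offset_dtype = model.offset.dtype

# One SGD step changes the parameter but leaves the buffer untouched.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.ones(3, dtype=torch.double)).sum()
loss.backward()
optimizer.step()
```

After the step, model.weight has moved away from its initial value while model.offset still holds zeros, even though both appear in state_dict() and both were cast by double().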

An example of buffers can be found in the _BatchNorm module, where running_mean, running_var and num_batches_tracked are registered as buffers and updated by accumulating statistics of the data forwarded through the layer. This is in contrast to the weight and bias parameters, which learn an affine transformation of the data through regular SGD optimization.
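The BatchNorm behaviour can be observed directly. The following is a minimal sketch using nn.BatchNorm1d with default settings; the input shape and the shift of 5.0 are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)

# Buffers and parameters are listed separately by the module.
buffer_names = sorted(name for name, _ in bn.named_buffers())
param_names = sorted(name for name, _ in bn.named_parameters())

mean_before = bn.running_mean.clone()
count_before = bn.num_batches_tracked.item()

# A forward pass in training mode updates the running statistics in place;
# no backward pass and no optimizer are involved.
bn.train()
with torch.no_grad():
    bn(torch.randn(8, 4) + 5.0)

mean_after = bn.running_mean.clone()
count_after = bn.num_batches_tracked.item()
```

Here running_mean and num_batches_tracked change as a side effect of the forward pass alone, whereas weight and bias would only change once an optimizer steps on their gradients.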
