Which PyTorch modules are affected by model.eval() and model.train()?
Problem description
The model.eval() method modifies certain modules (layers) which are required to behave differently during training and inference. Some examples are listed in the docs:
This has [an] effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.
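The effect of the two mode switches can be seen directly: train() and eval() set the module's training flag (recursively, on all submodules), and mode-sensitive layers such as Dropout consult that flag in their forward pass. A minimal sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A model containing one mode-sensitive layer (Dropout).
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

model.train()                   # dropout active: activations randomly zeroed
print(model.training)           # True

model.eval()                    # dropout disabled: forward pass is deterministic
print(model.training)           # False

out1 = model(x)
out2 = model(x)
print(torch.equal(out1, out2))  # True in eval mode
```

In train mode, by contrast, two forward passes of the same input would generally differ, because Dropout samples a fresh mask each time.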
Is there an exhaustive list of which modules are affected?
Recommended answer
In addition to the info provided by @iacob:
Base class | Module | Criteria
---|---|---
Other normalization layers EXCEPT LocalResponseNorm | GroupNorm, LayerNorm | track_running_stats=True
RNNBase | RNN, LSTM, GRU | dropout > 0 (default: 0)
Transformer layers | Transformer, TransformerEncoder, TransformerDecoder | dropout > 0 (Transformer default: 0.1); if norm set to a normalization layer
Lazy variants | LazyBatchNorm (currently nightly, merged PR) | track_running_stats=True
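The criteria in the table can be checked programmatically. Below is a hypothetical helper (mode_sensitive_modules is not a PyTorch API, just an illustrative sketch) that walks a model and flags submodules that behave differently between train and eval mode; it covers the dropout and batch-norm cases and uses the private base class nn.modules.batchnorm._BatchNorm as a convenient catch-all for the BatchNorm variants:

```python
import torch.nn as nn

def mode_sensitive_modules(model):
    """Sketch: list (name, class) of submodules whose forward pass
    depends on the training flag, per the criteria in the table above."""
    sensitive = []
    for name, m in model.named_modules():
        # Dropout variants are always mode-sensitive.
        if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)):
            sensitive.append((name, type(m).__name__))
        # BatchNorm layers switch between batch stats and running stats.
        elif isinstance(m, nn.modules.batchnorm._BatchNorm) and m.track_running_stats:
            sensitive.append((name, type(m).__name__))
        # RNN/LSTM/GRU apply inter-layer dropout only when dropout > 0.
        elif isinstance(m, nn.RNNBase) and m.dropout > 0:
            sensitive.append((name, type(m).__name__))
    return sensitive

model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Dropout(0.5),
)
print(mode_sensitive_modules(model))
# [('1', 'BatchNorm2d'), ('3', 'Dropout')]
```

Conv2d and ReLU are not listed, since their behavior is identical in both modes.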