What does model.eval() do in PyTorch?
Question
I am using this code, and saw model.eval() in some cases.
I understand it is supposed to allow me to "evaluate my model", but I don't understand when I should and shouldn't use it, or how to turn it off.
I would like to run the above code to train the network and also be able to run validation every epoch, but I still haven't managed to do it.
Answer
model.eval() is a kind of switch for specific layers/parts of the model that behave differently during training and during inference (evaluation), for example Dropout layers and BatchNorm layers. During evaluation, Dropout should stop randomly zeroing activations, and BatchNorm should normalize with its running statistics instead of the current batch's statistics; .eval() does this for you (it is equivalent to model.train(False)). In addition, the common practice for evaluation/validation is to pair torch.no_grad() with model.eval() to turn off gradient computation:
# evaluate model:
model.eval()
with torch.no_grad():
    ...
    out_data = model(data)
    ...
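The mode switch can be observed directly. What follows is a minimal sketch that assumes a small hypothetical model (nn.Sequential built from Linear, BatchNorm1d and Dropout) and random input data. It shows that eval() clears the training flag on every submodule, that eval-mode outputs are deterministic, and that tensors computed under torch.no_grad() do not track gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny model containing both layer types whose behavior depends on the mode.
model = nn.Sequential(
    nn.Linear(4, 4),
    nn.BatchNorm1d(4),
    nn.Dropout(p=0.5),
)
x = torch.randn(8, 4)

model.train()  # training mode: Dropout zeroes activations at random, BatchNorm uses batch statistics
train_out = model(x)

model.eval()  # evaluation mode: Dropout is a no-op, BatchNorm uses its running statistics
with torch.no_grad():  # no autograd graph is recorded inside this block
    eval_out_1 = model(x)
    eval_out_2 = model(x)

# eval() sets the `training` flag to False on the model and on every submodule
print(model.training, [m.training for m in model])  # False [False, False, False]
# In eval mode, repeated forward passes on the same input give identical results
print(torch.equal(eval_out_1, eval_out_2))  # True
# Outputs computed under no_grad do not require gradients
print(train_out.requires_grad, eval_out_1.requires_grad)  # True False

# model.train() (equivalently model.train(True)) switches back to training mode
model.train()
print(model.training)  # True
```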
But don't forget to switch back to training mode with model.train() after the evaluation step:
# training step
...
model.train()
...
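Applied to the situation in the question (training the network and running validation after every epoch), a minimal sketch might look like the following. The synthetic data, model architecture, loss and optimizer settings are hypothetical placeholders, not taken from the code the question refers to. The key point is that model.train() is called at the start of every training epoch and model.eval() together with torch.no_grad() at the start of every validation pass.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical synthetic regression data standing in for a real dataset.
X = torch.randn(256, 10)
y = X.sum(dim=1, keepdim=True)
train_loader = DataLoader(TensorDataset(X[:200], y[:200]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=32)

# Hypothetical model with mode-dependent layers (BatchNorm1d, Dropout).
model = nn.Sequential(
    nn.Linear(10, 32),
    nn.BatchNorm1d(32),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(32, 1),
)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


def train_one_epoch():
    model.train()  # enable Dropout and batch-statistics BatchNorm
    total = 0.0
    for xb, yb in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        optimizer.step()
        total += loss.item() * xb.size(0)
    return total / len(train_loader.dataset)


def validate():
    model.eval()  # disable Dropout, make BatchNorm use its running statistics
    total = 0.0
    with torch.no_grad():  # gradients are not needed for validation
        for xb, yb in val_loader:
            total += loss_fn(model(xb), yb).item() * xb.size(0)
    return total / len(val_loader.dataset)


history = []
for epoch in range(5):
    train_loss = train_one_epoch()  # switches to training mode itself
    val_loss = validate()           # switches to evaluation mode itself
    history.append((train_loss, val_loss))
    print(f"epoch {epoch}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")
```

Because each function sets the mode it needs on entry, forgetting to switch back after validation cannot silently leave Dropout and BatchNorm in evaluation behavior for the next training epoch.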