在PyTorch中保存经过训练的模型的最佳方法? [英] Best way to save a trained model in PyTorch?

查看:1028
本文介绍了在PyTorch中保存经过训练的模型的最佳方法?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我一直在寻找其他方法来在PyTorch中保存经过训练的模型。到目前为止,我已经找到了两种选择。

I was looking for alternative ways to save a trained model in PyTorch. So far, I have found two alternatives.


  1. torch.save()保存模型并 torch.load()来加载模型。

  2. model.state_dict()以保存训练有素的模型和 model.load_state_dict()来加载保存的模型。

  1. torch.save() to save a model and torch.load() to load a model.
  2. model.state_dict() to save a trained model and model.load_state_dict() to load the saved model.

我遇到了这个讨论其中,建议方法2优于方法1。

I have come across to this discussion where approach 2 is recommended over approach 1.

我的问题是,为什么第二种方法是首选?仅仅是因为 torch.nn 模块具有这两个功能,我们鼓励使用它们吗? / p>

My question is, why the second approach is preferred? Is it only because torch.nn modules have those two function and we are encouraged to use them?

推荐答案

我发现此页面在他们的github存储库上,我只将内容粘贴在这里。

I've found this page on their github repo, I'll just paste the content here.

序列化和还原模型有两种主要方法。

There are two main approaches for serializing and restoring a model.

第一个(推荐)保存并仅加载模型参数:

The first (recommended) saves and loads only the model parameters:

torch.save(the_model.state_dict(), PATH)

然后,然后:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二个保存并加载整个模型:

The second saves and loads the entire model:

torch.save(the_model, PATH)

然后稍后:

the_model = torch.load(PATH)

但是,在这种情况下,序列化的数据绑定到特定的类
和使用的确切目录结构,因此可以在其他项目中使用
时或经过一些严重的重构后,它会以各种方式中断。

However in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.

这篇关于在PyTorch中保存经过训练的模型的最佳方法?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆