在 PyTorch 中保存训练模型的最佳方法? [英] Best way to save a trained model in PyTorch?

查看:39
本文介绍了在 PyTorch 中保存训练模型的最佳方法?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我一直在寻找在 PyTorch 中保存训练模型的替代方法.到目前为止,我找到了两个替代方案.

I was looking for alternative ways to save a trained model in PyTorch. So far, I have found two alternatives.

  1. torch.save() 保存模型和 torch.load() 加载模型.
  2. model.state_dict() 保存训练好的模型和 model.load_state_dict() 加载保存的模型.
  1. torch.save() to save a model and torch.load() to load a model.
  2. model.state_dict() to save a trained model and model.load_state_dict() to load the saved model.

我遇到过这个讨论 推荐方法 2 而非方法 1.

I have come across to this discussion where approach 2 is recommended over approach 1.

我的问题是,为什么首选第二种方法?是否仅仅是因为 torch.nn 模块具有这两个功能并且我们被鼓励使用它们?

My question is, why the second approach is preferred? Is it only because torch.nn modules have those two function and we are encouraged to use them?

推荐答案

我找到了 这个页面在他们的 github repo 上,我会把内容贴在这里.

I've found this page on their github repo, I'll just paste the content here.

序列化和恢复模型有两种主要方法.

There are two main approaches for serializing and restoring a model.

第一个(推荐)只保存和加载模型参数:

The first (recommended) saves and loads only the model parameters:

torch.save(the_model.state_dict(), PATH)

然后:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二个保存并加载整个模型:

The second saves and loads the entire model:

torch.save(the_model, PATH)

然后:

the_model = torch.load(PATH)

但是在这种情况下,序列化的数据绑定到特定的类以及使用的确切目录结构,因此它可以以各种方式中断在其他项目中使用,或者经过一些严重的重构.

However in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.

这篇关于在 PyTorch 中保存训练模型的最佳方法?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆