在无法访问模型类代码的情况下保存 PyTorch 模型 [英] Saving PyTorch model with no access to model class code

查看:25
本文介绍了在无法访问模型类代码的情况下保存 PyTorch 模型的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

如何在不需要在某处定义模型类的情况下保存 PyTorch 模型?

How can I save a PyTorch model without a need for the model class to be defined somewhere?

免责声明:

在 PyTorch 中保存训练模型的最佳方法?没有解决方案(或有效的解决方案)可以在不访问模型类代码的情况下保存模型.

In Best way to save a trained model in PyTorch?, there are no solutions (or a working solution) for saving the model without access to the model class code.

推荐答案

如果您打算使用可用的 Pytorch 库(即 Python、C++ 或它支持的其他平台中的 Pytorch)进行推理,那么最好的方法是通过 TorchScript.

If you plan to do inference with the Pytorch library available (i.e. Pytorch in Python, C++, or other platforms it supports) then the best way to do this is via TorchScript.

我认为最简单的方法是使用 trace = torch.jit.trace(model,typical_input)torch.jit.save(trace, path).然后,您可以使用 torch.jit.load(path) 加载跟踪模型.

I think the simplest thing is to use trace = torch.jit.trace(model, typical_input) and then torch.jit.save(trace, path). You can then load the traced model with torch.jit.load(path).

这是一个非常简单的例子.我们制作了两个文件:

Here's a really simple example. We make two files:

train.py :

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = torch.relu(self.linear(x))
        return x

model = Model()
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
with torch.no_grad():
    print(model(x))
    traced_cell = torch.jit.trace(model, (x))
torch.jit.save(traced_cell, "model.pth")

infer.py :

import torch
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
loaded_trace = torch.jit.load("model.pth")
with torch.no_grad():
    print(loaded_trace(x))

按顺序运行这些会给出结果:

Running these sequentially gives results:

python train.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

python infer.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

结果是一样的,所以我们很好.(注意,由于nn.Linear层初始化的随机性,这里每次的结果都会不同).

The results are the same, so we are good. (Note that the result will be different each time here due to randomness of the initialisation of the nn.Linear layer).

TorchScript 提供了将更复杂的架构和图形定义(包括 if 语句、while 循环等)保存在单个文件中,而无需在推理时重新定义图形.有关更高级的可能性,请参阅文档(上面链接).

TorchScript provides for much more complex architectures and graph definitions (including if statements, while loops, and more) to be saved in a single file, without needing to redefine the graph at inference time. See the docs (linked above) for more advanced possibilities.

这篇关于在无法访问模型类代码的情况下保存 PyTorch 模型的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆