How does one save torch.nn.Sequential models in pytorch properly?


Problem description

I am very well aware of loading the state dictionary and then having an instance of the model load the old dictionary of parameters (e.g. this great question & answer). Unfortunately, when I have a torch.nn.Sequential I of course do not have a class definition for it.
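For context, this is the usual state_dict pattern I mean (a minimal sketch; the MyModel class and the checkpoint.pt file name are just placeholders for illustration):

# Minimal sketch of the usual state_dict workflow.
# MyModel and 'checkpoint.pt' are made-up names used only for illustration.
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1)

    def forward(self, x):
        return self.lin(x)

model = MyModel()
torch.save(model.state_dict(), 'checkpoint.pt')   # save only the parameters

model2 = MyModel()                                 # the class definition is needed here
model2.load_state_dict(torch.load('checkpoint.pt'))
model2.eval()

With a plain nn.Sequential there is no such class I can re-instantiate, which is exactly what worries me.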

So I wanted to double check: what is the proper way to do it? I believe torch.save is sufficient (so far my code has not crashed), though these things can be more subtle than one might expect (e.g. I get a warning when I use pickle, but torch.save uses it internally, so it's confusing). Also, numpy has its own save functions (e.g. see this answer) which tend to be more efficient, so there might be a subtle trade-off I am overlooking.

My test code:


# creating data and running through a nn and saving it

import torch
import torch.nn as nn

from pathlib import Path
from collections import OrderedDict

import numpy as np

import pickle

path = Path('~/data/tmp/').expanduser()
path.mkdir(parents=True, exist_ok=True)

num_samples = 3
Din, Dout = 1, 1
lb, ub = -1, 1

x = torch.distributions.Uniform(low=lb, high=ub).sample((num_samples, Din))

f = nn.Sequential(OrderedDict([
    ('f1', nn.Linear(Din, Dout)),
    ('out', nn.SELU())
]))
y = f(x)

# save data torch to numpy
x_np, y_np = x.detach().cpu().numpy(), y.detach().cpu().numpy()
np.savez(path / 'db', x=x_np, y=y_np)

print(x_np)
# save model
with open('db_saving_seq', 'wb') as file:
    pickle.dump({'f': f}, file)

# load model
with open('db_saving_seq', 'rb') as file:
    db = pickle.load(file)
    f2 = db['f']

# test that it outputs the right thing
y2 = f2(x)

y_eq_y2 = y == y2
print(y_eq_y2)

db2 = {'f': f, 'x': x, 'y': y}
torch.save(db2, path / 'db_f_x_y')

print('Done')

db3 = torch.load(path / 'db_f_x_y')
f3 = db3['f']
x3 = db3['x']
y3 = db3['y']
yy3 = f3(x3)

y_eq_y3 = y == y3
print(y_eq_y3)

y_eq_yy3 = y == yy3
print(y_eq_yy3)


Recommended answer

As can be seen in the code, torch.nn.Sequential is based on torch.nn.Module: https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential

So you can use

f = torch.nn.Sequential(...)
torch.save(f.state_dict(), path)

just like any other torch.nn.Module.
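To restore it later, rebuild a Sequential with the same layout and load the saved parameters into it. A minimal sketch (the file name seq.pt is arbitrary):

# Sketch of the matching load step: the same architecture is rebuilt,
# then the saved state_dict is loaded into it. 'seq.pt' is an arbitrary name.
import torch
import torch.nn as nn
from collections import OrderedDict

f = nn.Sequential(OrderedDict([
    ('f1', nn.Linear(1, 1)),
    ('out', nn.SELU())
]))
torch.save(f.state_dict(), 'seq.pt')

# re-create the same Sequential and load the saved parameters into it
f_loaded = nn.Sequential(OrderedDict([
    ('f1', nn.Linear(1, 1)),
    ('out', nn.SELU())
]))
f_loaded.load_state_dict(torch.load('seq.pt'))
f_loaded.eval()

Saving the whole Sequential object with torch.save(f, path) also works (it pickles the object), but the state_dict route only requires that you can rebuild the same layer layout when loading.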
