How does one pickle arbitrary pytorch models that use lambda functions?

Question

I currently have a neural network module:

import torch.nn as nn

class NN(nn.Module):
    def __init__(self, args, lambda_f, nn1, loss, opt):
        super().__init__()
        self.args = args
        self.lambda_f = lambda_f  # plain attribute: not included in state_dict()
        self.nn1 = nn1
        self.loss = loss
        self.opt = opt
        # more nn.Parameter stuff etc...

    def forward(self, x):
        # some code using fields, e.g. (illustrative only):
        out = self.lambda_f(self.nn1(x))
        return out

I am trying to checkpoint it, but because pytorch saves using state_dicts, I can't save the lambda functions I was actually using if I checkpoint with pytorch's torch.save etc. I literally want to save everything without issue and reload it to train on GPUs later. I am currently using this:

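from pathlib import Path
import dill as pickle

def save_ckpt(path_to_ckpt, crazy_mdl, db=None):
    path_to_ckpt = Path(path_to_ckpt)
    ## Make dir. Throw no exceptions if it already exists
    path_to_ckpt.mkdir(parents=True, exist_ok=True)
    ckpt_path_plus_path = path_to_ckpt / Path('db')

    ## Pickle the model (lambdas included) with dill
    db = {} if db is None else db
    db['crazy_mdl'] = crazy_mdl
    # 'wb' overwrites the file; 'ab' would append a new pickle on every call,
    # and a single pickle.load would only ever return the first (oldest) one
    with open(ckpt_path_plus_path, 'wb') as db_file:
        pickle.dump(db, db_file)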
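
For context on why the question reaches for dill: torch.save relies on Python's standard pickle module by default, and that module stores functions by their importable name, so a lambda (named "<lambda>") cannot be looked up again. A minimal stdlib-only sketch, where Holder, add_one and try_standard_pickle are hypothetical names standing in for a model object with a function attribute:

```python
import pickle

class Holder:
    """Hypothetical stand-in for a model object that stores a function attribute."""
    def __init__(self, f):
        self.f = f

def try_standard_pickle(obj):
    """Return True if the standard-library pickle can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        # functions are pickled by qualified name; "<lambda>" is not importable
        return False

def add_one(x):
    return x + 1

print(try_standard_pickle(Holder(add_one)))          # True: module-level function, stored by reference
print(try_standard_pickle(Holder(lambda x: x + 1)))  # False: a lambda has no importable name
```

dill works around this by serializing the lambda's code itself, which is what the question relies on.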

Currently it throws no errors when I checkpoint it, and the checkpoint is saved.

I am worried that when I train it there might be a subtle bug even if no exceptions/errors are raised, or that something unexpected might happen (e.g. weird saving on disks in the clusters, etc., who knows).

Is this safe to do with pytorch classes/nn models? Especially if we want to resume training with GPUs?

Cross-posted:

  • How does one pickle arbitrary pytorch models that use lambda functions?
  • https://discuss.pytorch.org/t/how-does-one-pickle-arbitrary-pytorch-models-that-use-lambda-functions/79026
  • https://www.reddit.com/r/pytorch/comments/gagpjg/how_does_one_pickle_arbitrary_pytorch_models_that/?
  • https://www.quora.com/unanswered/How-does-one-pickle-arbitrary-PyTorch-models-that-use-lambda-functions

Recommended answer

I'm the dill author. I use dill (and klepto) to save classes that contain trained ANNs inside of lambda functions. I tend to use combinations of mystic and sklearn, so I can't speak directly to pytorch, but I'd assume it works the same. The place where you have to be careful is if you have a lambda that contains a pointer to an object external to the lambda, for example y = 4; f = lambda x: x+y. This might seem obvious, but dill will pickle the lambda and, depending on the rest of the code and the serialization variant, may not serialize the value of y. So I've seen many cases where people serialize a trained estimator inside some function (or lambda, or class) and then the results aren't "correct" when they restore the function from serialization.

The overarching cause is that the function wasn't encapsulated, so not all of the objects it needs to yield the correct results were stored in the pickle. Even in that case you can get the "correct" results back, but you'd need to recreate the same environment you had when you pickled the estimator (i.e. all the same values it depends on in the surrounding namespace). The takeaway: try to make sure that all variables used in the function are defined within the function. Here's a portion of a class I've recently started to use myself (should be in the next release of mystic):
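
The y = 4; f = lambda x: x+y pitfall can be reproduced with the standard-library pickle instead of dill. pickle is swapped in here only so the sketch runs without extra packages; dill's handling of lambdas and their globals depends on its settings, as the answer notes. A named module-level function is stored by reference, so any global it reads is resolved from whatever environment exists when it is called after loading. Binding the value with functools.partial is one way to keep it inside the pickle. The function names are hypothetical:

```python
import functools
import pickle

y = 4

def add_global(x):
    # y is looked up in the module namespace at call time; it is NOT stored in the pickle
    return x + y

def add(x, y):
    return x + y

blob_global = pickle.dumps(add_global)
# a partial pickles a reference to add plus the bound value y=4
blob_bound = pickle.dumps(functools.partial(add, y=4))

y = 100  # the "environment" changes between saving and loading

print(pickle.loads(blob_global)(1))  # 101: result depends on the current y
print(pickle.loads(blob_bound)(1))   # 5: the value 4 travelled with the pickle
```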

import numpy as np  # the lambda below resolves np as a module-level global when called

class Estimator(object):
    "a container for a trained estimator and transform (not a pipeline)"
    def __init__(self, estimator, transform):
        """a container for a trained estimator and transform

        Input:
            estimator: a fitted sklearn estimator
            transform: a fitted sklearn transform
        """
        self.estimator = estimator
        self.transform = transform
        # treat x as one sample, transform it, predict, and return a Python float
        self.function = lambda *x: float(self.estimator.predict(self.transform.transform(np.array(x).reshape(1,-1))).reshape(-1)[0])
    def __call__(self, *x):
        "f(*x) for x of xtest and predict on fitted estimator(transform(xtest))"
        return self.function(*x)

Note that when the function is called, everything it uses (including np) must be defined in the surrounding namespace. As long as pytorch estimators serialize as expected (without external references), you should be fine if you follow the above guidelines.
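
If dill is not an option, one variant of the same container keeps the prediction logic in a regular method instead of a lambda attribute; standard pickle then stores the class by reference and the fitted objects as ordinary attributes. This is a swapped-in alternative, not the answer's mystic code, and Scale, SumModel and MethodEstimator are hypothetical stand-ins (plain Python, not sklearn or PyTorch objects):

```python
import pickle

class Scale:
    """Hypothetical stand-in for a fitted transform."""
    def __init__(self, factor):
        self.factor = factor
    def transform(self, rows):
        return [[v * self.factor for v in row] for row in rows]

class SumModel:
    """Hypothetical stand-in for a fitted estimator."""
    def predict(self, rows):
        return [sum(row) for row in rows]

class MethodEstimator:
    "container whose prediction logic is a method, so plain pickle can store it"
    def __init__(self, estimator, transform):
        self.estimator = estimator
        self.transform = transform
    def __call__(self, *x):
        # everything needed lives on self; no free globals are read at call time
        return float(self.estimator.predict(self.transform.transform([list(x)]))[0])

est = MethodEstimator(SumModel(), Scale(2.0))
restored = pickle.loads(pickle.dumps(est))
print(restored(1, 2, 3))  # 12.0
```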
