Saving an sklearn `FunctionTransformer` with the function it wraps


Question

I am using sklearn's Pipeline and FunctionTransformer with a custom function

import joblib
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline

Here is my code:

def f(x):
    return x*2
pipe = Pipeline([("times_2", FunctionTransformer(f))])
joblib.dump(pipe, "pipe.joblib")
del pipe
del f
pipe = joblib.load("pipe.joblib") # Causes an exception

I get this error:

AttributeError: module '__main__' has no attribute 'f'

How can I fix this?

Note that the same problem occurs with plain pickle.
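The root cause is that pickle (and therefore joblib) serializes a plain function by reference: it records only the module and name (here `__main__.f`) and looks that name up again when loading. The stdlib-only sketch below reproduces the failure without sklearn; it assumes the script is run directly, so the function lives in `__main__`.

```python
import pickle

def f(x):
    # A module-level function: pickle stores only a reference to it, not its code.
    return x * 2

# The payload holds the module name and the function name, nothing more.
payload = pickle.dumps(f)

# Simulate the question's `del f` (or loading in a fresh session where f is undefined).
del f

try:
    pickle.loads(payload)
except AttributeError as exc:
    # Unpickling tries getattr(__main__, "f") and fails.
    error_message = str(exc)
else:
    error_message = None

print(error_message)
```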

Recommended answer

I was able to hack together a solution using the marshal module (in addition to pickle) and by overriding the magic methods __getstate__ and __setstate__, which pickle uses.

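import marshal
from types import FunctionType
from sklearn.base import BaseEstimator, TransformerMixin

class MyFunctionTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, func):
        # The attribute name must match the parameter name, because
        # BaseEstimator.get_params() (used by repr, clone, etc.) looks it up.
        self.func = func
    def __call__(self, X):
        return self.func(X)
    def __getstate__(self):
        # Work on a copy so pickling does not strip `func` from the live object.
        state = self.__dict__.copy()
        func = state.pop("func")
        state["func_name"] = func.__name__
        state["func_code"] = marshal.dumps(func.__code__)
        return state
    def __setstate__(self, state):
        # Rebuild the function from its marshaled code object. Global names are
        # resolved in this module; default arguments and closures are not restored.
        code = marshal.loads(state.pop("func_code"))
        state["func"] = FunctionType(code, globals(), state.pop("func_name"))
        self.__dict__.update(state)
    def fit(self, X, y=None):
        return self
    def transform(self, X):
        return self.func(X)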
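The trick rests on two pickle hooks: `__getstate__` decides what gets written, and `__setstate__` receives that state back when loading. This sklearn-free sketch uses a hypothetical `Scaler` class (not part of any library) that holds an unpicklable lambda, drops it from a copy of its state, and rebuilds it on load.

```python
import pickle

class Scaler:
    """Hypothetical example class illustrating the pickle hooks."""

    def __init__(self, factor):
        self.factor = factor
        # A lambda cannot be pickled, so it must be kept out of the saved state.
        self.op = lambda x: x * self.factor

    def __getstate__(self):
        # Called by pickle when dumping; copy so the live object stays usable.
        state = self.__dict__.copy()
        del state["op"]
        return state

    def __setstate__(self, state):
        # Called by pickle when loading, with whatever __getstate__ returned.
        self.__dict__.update(state)
        self.op = lambda x: x * self.factor

s = Scaler(3)
restored = pickle.loads(pickle.dumps(s))
print(s.op(2), restored.op(2))  # 6 6
```

Returning a copy from `__getstate__` matters: deleting the attribute from `self.__dict__` directly would leave the original object broken after it has been pickled once.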

Now, if we use MyFunctionTransformer instead of FunctionTransformer, the code works as expected:

import joblib
from sklearn.pipeline import Pipeline

@MyFunctionTransformer
def my_transform(x):
    return x*2
pipe = Pipeline([("times_2", my_transform)])
joblib.dump(pipe, "pipe.joblib")
del pipe
del my_transform
pipe = joblib.load("pipe.joblib")

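This works by removing the function f from the pickled state and instead storing its marshaled code object and its name; on loading, __setstate__ rebuilds the function from those two pieces. Because marshal's format may change between Python versions, the file should be loaded with the same Python version that wrote it.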
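The core of that round trip can be isolated with the standard library alone. In this sketch `double` is a hypothetical stand-in for the user's function: its code object is serialized with `marshal`, the function is deleted, and a new function is built from the bytes with `types.FunctionType`.

```python
import marshal
from types import FunctionType

def double(x):
    return x * 2

# marshal can serialize code objects, which pickle cannot.
code_bytes = marshal.dumps(double.__code__)
name = double.__name__

del double  # the original function object is gone

# Rebuild a function from the code, attaching this module's globals.
# Only the code is restored: default arguments and closures are not.
rebuilt = FunctionType(marshal.loads(code_bytes), globals(), name)
print(rebuilt(21))  # 42
```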

The third-party dill library also looks like a good alternative to marshal, since it can serialize functions by value.
