Handmade Estimator modifies parameters in __init__?


Question

I am preparing a tailored preprocessing phase which is supposed to become part of a sklearn.pipeline.Pipeline. Here is the code of the preprocessor:

import string
from nltk import wordpunct_tokenize
from nltk.stem.snowball import SnowballStemmer
from nltk import sent_tokenize
from sklearn.base import BaseEstimator, TransformerMixin
from . import stopwords

class NLTKPreprocessor(BaseEstimator, TransformerMixin):
    def __init__(self, stopwords=stopwords.STOPWORDS_DE,
                 punct=string.punctuation,
                 lower=True, strip=True, lang='german'):
        """
        Based on:
        https://bbengfort.github.io/tutorials/2016/05/19/text-classification-nltk-sckit-learn.html
        """

        self.lower = lower
        self.strip = strip
        self.stopwords = set(stopwords)
        self.punct = set(punct)
        self.stemmer = SnowballStemmer(lang)
        self.lang = lang

    def fit(self, X, y=None):
        return self

    def inverse_transform(self, X):
        return [" ".join(doc) for doc in X]

    def transform(self, X):
        return [
            list(self.tokenize(doc)) for doc in X
        ]

    def tokenize(self, document):
        # Break the document into sentences
        for sent in sent_tokenize(document, self.lang):
            for token in wordpunct_tokenize(sent):
                # Apply preprocessing to the token
                if self.lower:
                    token = token.lower()
                if self.strip:
                    token = token.strip().strip('_').strip('*')

                # If stopword, ignore token and continue
                if token in self.stopwords:
                    continue

                # If punctuation, ignore token and continue
                if all(char in self.punct for char in token):
                    continue

                # Lemmatize the token and yield
                # lemma = self.lemmatize(token, tag)
                stem = self.stemmer.stem(token)
                yield stem

Next, here is the pipeline I construct:

pipeline = Pipeline(
    [
        ('preprocess', nltkPreprocessor),
        ('vectorize', TfidfVectorizer(tokenizer=identity, preprocessor=None, lowercase=False)),
        ('clf', SGDClassifier(max_iter=1000, tol=1e-3))       
    ]
)

This all works nicely for a single pass; for example pipeline.fit(X, y) works fine. However, when I put this pipeline inside a grid search

parameters = {
    'vectorize__use_idf': (True, False),
    'vectorize__max_df': np.arange(0.8, 1.01 ,0.05),
    'vectorize__smooth_idf': (True, False),
    'vectorize__sublinear_tf': (True, False),
    'vectorize__norm': ('l1', 'l2'),
    'clf__loss':  ('hinge', 'log', 'modified_huber', 'squared_hinge', 'perceptron'),
    'clf__alpha': (0.00001, 0.000001),
    'clf__penalty': ('l1', 'l2', 'elasticnet')
}
grid_search = GridSearchCV(pipeline, parameters, n_jobs=-1, verbose=1)
grid_search.fit(X_train, y_train)

I get the following warning:

/Users/user/anaconda3/envs/myenv/lib/python3.6/site-packages/sklearn/base.py:115: DeprecationWarning: Estimator NLTKPreprocessor modifies parameters in __init__. This behavior is deprecated as of 0.18 and support for this behavior will be removed in 0.20.
  % type(estimator).__name__, DeprecationWarning)

I don't understand what should be changed or fixed in the implementation. How can I keep the functionality and get rid of the warning?

Answer

Check out the developer guide of sklearn, here and the following paragraphs. I would adhere to it as closely as possible to make sure such messages are avoided (even if you never intend to contribute the estimator).

It prescribes that estimators should have no logic in the __init__ method. That is most likely what triggers your warning.

I put validation or transformation of the init parameters (as the guide also prescribes) at the beginning of the fit() method, which has to be called in any case.

Also, note this utility, which you can use to test whether your estimator conforms to the scikit-learn API.
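The utility referred to is presumably sklearn.utils.estimator_checks.check_estimator (an assumption, since the original link is not preserved). A minimal sketch showing it flag exactly this kind of __init__ mutation; the class name is hypothetical:

```python
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.estimator_checks import check_estimator

class BadInitTransformer(BaseEstimator, TransformerMixin):
    """Deliberately violates the contract: __init__ mutates its argument."""
    def __init__(self, stopwords=("the",)):
        self.stopwords = set(stopwords)  # stored value != passed value

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X

# check_estimator runs a battery of API-conformance checks and raises
# on the first violation -- here, that clone() cannot round-trip the
# mutated `stopwords` parameter.
try:
    check_estimator(BadInitTransformer())
except Exception as exc:
    print(type(exc).__name__)
```

The exact exception type (RuntimeError vs. AssertionError) depends on the scikit-learn version, but some check will reject this estimator.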

Well, not no logic. To quote from the linked guide: "To summarize, an __init__ should look like:

def __init__(self, param1=1, param2=2):
    self.param1 = param1
    self.param2 = param2

There should be no logic, not even input validation, and the parameters should not be changed."

So I guess, as @uberwach detailed, the set construction and the creation of the SnowballStemmer instance probably violate the "should not be changed" part.
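The reason this contract matters: GridSearchCV clones each estimator by reading get_params() and calling __init__ again with those values, and BaseEstimator.get_params simply reads back the attributes named after the __init__ arguments. If __init__ stores a transformed value, the round trip no longer reproduces what the caller passed in. A minimal illustration (the class name is hypothetical):

```python
from sklearn.base import BaseEstimator

class MutatingEstimator(BaseEstimator):
    """Illustration only: stores a *transformed* value in __init__."""
    def __init__(self, stopwords=("the", "a")):
        self.stopwords = set(stopwords)  # violates the contract

est = MutatingEstimator(stopwords=["the", "a"])
# get_params() reads back self.stopwords, so the caller no longer gets
# the list that was passed in -- cloning cannot round-trip it, which is
# what sklearn's deprecation warning complains about.
print(type(est.get_params()["stopwords"]))  # <class 'set'>
```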

As an addition to the comment below: this would be one general way of doing it (another, more specific way, mentioned by @uberwach, goes in your tokenize method):

class NLTKPreprocessor(BaseEstimator, TransformerMixin):
    def __init__(self, stopwords=stopwords.STOPWORDS_DE,
                 punct=string.punctuation,
                 lower=True, strip=True, lang='german'):
        self.lower = lower
        self.strip = strip
        self.stopwords = stopwords
        self.punct = punct
        self.lang = lang

    def fit(self, X, y=None):
        self.stopword_set = set(self.stopwords)
        self.punct_set = set(self.punct)
        self.stemmer = SnowballStemmer(self.lang)
        return self
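A trimmed-down, NLTK-free sketch of the same pattern, showing that it satisfies the get_params()/clone() contract (class and attribute names here are illustrative, not from the original answer):

```python
from sklearn.base import BaseEstimator, TransformerMixin, clone

class LazyPreprocessor(BaseEstimator, TransformerMixin):
    """Sketch: __init__ only stores; derived state is built in fit()."""
    def __init__(self, stopwords=("der", "die", "das"), lower=True):
        self.stopwords = stopwords  # stored verbatim, never transformed
        self.lower = lower

    def fit(self, X, y=None):
        # Derived state lives under a different attribute name.
        self.stopword_set_ = set(self.stopwords)
        return self

    def transform(self, X):
        docs = (t.lower() if self.lower else t for t in X)
        return [t for t in docs if t not in self.stopword_set_]

pre = LazyPreprocessor(stopwords=["foo"])
assert pre.get_params()["stopwords"] == ["foo"]  # round-trips unchanged
clone(pre)                                       # no warning, no error
print(pre.fit(None).transform(["Foo", "bar"]))   # ['bar']
```

Note the trailing underscore on stopword_set_: by scikit-learn convention, attributes created during fit() get a trailing underscore to distinguish them from constructor parameters.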

