scikit-learn: custom classifier compatible with GridSearchCV

Question

I have implemented my own classifier and now I want to run a grid search over it, but I'm getting the following error:

estimator.fit(X_train, y_train, **fit_params)
TypeError: fit() takes 2 positional arguments but 3 were given

I followed this tutorial and used this template provided by scikit-learn's official documentation. My class is defined as follows:

from sklearn.base import BaseEstimator, ClassifierMixin

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, lr=0.1):
        self.lr = lr

    def fit(self, X, y):
        # Some code
        return self

    def predict(self, X):
        # Some code (computes y_pred)
        return y_pred

    def get_params(self, deep=True):
        return {'lr': self.lr}

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self
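MyClassifier already inherits from BaseEstimator, so the hand-written get_params and set_params are not strictly needed. BaseEstimator reads the parameter names from the __init__ signature and looks up attributes with the same names. The sketch below is a simplified, hypothetical imitation of that mechanism in plain Python. The ParamsFromInit class is an illustration, not scikit-learn's code. It only works when __init__ stores each argument unchanged under the same name, as self.lr = lr does.

```python
import inspect


class ParamsFromInit:
    """Hypothetical, simplified imitation of BaseEstimator's parameter handling."""

    @classmethod
    def _param_names(cls):
        # Parameter names come from the __init__ signature, not from instance attributes.
        signature = inspect.signature(cls.__init__)
        skipped = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        return sorted(
            name
            for name, param in signature.parameters.items()
            if name != "self" and param.kind not in skipped
        )

    def get_params(self, deep=True):
        # deep is accepted for API compatibility; nested estimators are ignored here.
        return {name: getattr(self, name) for name in self._param_names()}

    def set_params(self, **params):
        valid = self._param_names()
        for name, value in params.items():
            if name not in valid:
                raise ValueError(f"Invalid parameter {name!r}")
            setattr(self, name, value)
        return self


class MyClassifier(ParamsFromInit):
    def __init__(self, lr=0.1):
        self.lr = lr  # stored unchanged under the same name


clf = MyClassifier()
print(clf.get_params())           # {'lr': 0.1}
print(clf.set_params(lr=0.5).lr)  # 0.5
```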

And I'm trying to run a grid search over it as follows:

from sklearn.model_selection import GridSearchCV

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)
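GridSearchCV turns param_grid into one candidate per combination of values. It then fits a fresh copy of the estimator once per candidate and per CV fold. That is where the estimator.fit(X_train, y_train, **fit_params) call in the traceback comes from. Below is a minimal sketch of the expansion step. The helper expand_grid is hypothetical; scikit-learn's own ParameterGrid plays this role.

```python
from itertools import product


def expand_grid(param_grid):
    """Hypothetical helper: {'lr': [0.1, 0.5]} -> [{'lr': 0.1}, {'lr': 0.5}]."""
    keys = sorted(param_grid)  # fixed key order, one list of values per key
    value_lists = [param_grid[key] for key in keys]
    return [dict(zip(keys, combo)) for combo in product(*value_lists)]


params = {'lr': [0.1, 0.5, 0.7]}
candidates = expand_grid(params)
print(candidates)  # [{'lr': 0.1}, {'lr': 0.5}, {'lr': 0.7}]

cv = 4
n_cv_fits = len(candidates) * cv  # one fit per candidate per fold
print(n_cv_fits)  # 12
```

With refit=True (the default), the best candidate is fitted once more on the full data. That makes 13 fit calls in total for this grid.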

EDIT I

This is how I'm calling it:

gs.fit(['hello world', 'trying', 'hello world', 'trying', 'hello world', 'trying', 'hello world', 'trying'], ['I', 'Z', 'I', 'Z', 'I', 'Z', 'I', 'Z'])

END EDIT I

The error is raised in the file python3.5/site-packages/sklearn/model_selection/_validation.py.

It calls estimator.fit(X_train, y_train, **fit_params) with 3 arguments, but my estimator's fit only takes two. So the error makes sense to me, but I don't know how to solve it. I also tried adding some dummy arguments to the fit method, but it didn't work.

EDIT II

Full error output:

Traceback (most recent call last):
  File "/home/rodrigo/no_version/text_classifier/MyClassifier.py", line 355, in <module>
    ['I', 'Z', 'I', 'Z', 'I', 'Z', 'I', 'Z'])
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/model_selection/_search.py", line 639, in fit
    cv.split(X, y, groups)))
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 779, in __call__
    while self.dispatch_one_batch(iterator):
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 625, in dispatch_one_batch
    self._dispatch(tasks)
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 588, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 111, in apply_async
    result = ImmediateResult(func)
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 332, in __init__
    self.results = batch()
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 131, in __call__
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 131, in <listcomp>
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/model_selection/_validation.py", line 458, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
TypeError: fit() takes 2 positional arguments but 3 were given

END EDIT II

SOLVED. Thank you all, I made a silly mistake: there were two different methods with the same name (fit). I had implemented the other one for custom purposes, with different parameters. As soon as I renamed my 'custom fit', it worked correctly.

Thanks, and sorry.
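The cause described in the SOLVED note is easy to reproduce in plain Python. When a class body contains two def statements with the same name, the later one silently replaces the earlier one, so the fit that actually runs is the last one defined. The second fit(self, X) below is a hypothetical stand-in for the asker's "custom fit", whose real signature isn't shown.

```python
class MyClassifier:
    def __init__(self, lr=0.1):
        self.lr = lr

    def fit(self, X, y):
        # The fit that GridSearchCV is supposed to call.
        self.classes_ = sorted(set(y))
        return self

    def fit(self, X):  # hypothetical "custom fit" with a different signature
        # Same name, defined later: this definition replaces the one above.
        return self


message = None
try:
    # Same call shape as estimator.fit(X_train, y_train) in _validation.py
    MyClassifier().fit([[0], [1]], [0, 1])
except TypeError as exc:
    message = str(exc)

print(message)  # "...fit() takes 2 positional arguments but 3 were given"
```

Renaming the second method restores the two-argument fit. Any distinct name works; fit_custom is just an illustrative example. Linters such as pylint report this pattern as function-redefined (E0102).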

Recommended answer

The following code works for me:

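import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, lr=0.1):
        # Store each __init__ argument unchanged so BaseEstimator.get_params() finds it
        self.lr = lr

    def fit(self, X, y):
        # Some code
        return self

    def predict(self, X):
        # Dummy prediction: the class label is X modulo 3
        return X % 3

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)

x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs.fit(x, y)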

The best I can figure is one of two things. Either you are passing something into gs.fit beyond x and y, or your MyClassifier.fit method is missing the self argument.

The fit_params kwargs are only populated if you pass keyword arguments to gs.fit. Otherwise fit_params is an empty dictionary ({}), and **fit_params does not cause an argument error. To test this, create an instance of your classifier and pass **{}. For example:

clf = MyClassifier()
clf.fit(x, y, **{})

This does not throw the positional arguments error.
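The same check can be written without scikit-learn. The function call_like_fit_and_score below is a hypothetical stand-in for the failing line in _validation.py. An empty fit_params dict adds no arguments. A stray keyword argument, by contrast, produces a different TypeError ("unexpected keyword argument") from the one in the traceback.

```python
class MyClassifier:
    def fit(self, X, y):
        return self


def call_like_fit_and_score(estimator, X_train, y_train, fit_params):
    # Hypothetical stand-in for: estimator.fit(X_train, y_train, **fit_params)
    return estimator.fit(X_train, y_train, **fit_params)


clf = MyClassifier()

# **{} expands to zero keyword arguments, so this is just clf.fit(X, y).
result = call_like_fit_and_score(clf, [[1], [2]], [0, 1], {})

# A non-empty fit_params (as in the answer's gs.fit(x, y, some_arg=123)) fails differently.
extra_kwarg_message = None
try:
    call_like_fit_and_score(clf, [[1], [2]], [0, 1], {"some_arg": 123})
except TypeError as exc:
    extra_kwarg_message = str(exc)

print(result is clf)        # True
print(extra_kwarg_message)  # "...fit() got an unexpected keyword argument 'some_arg'"
```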

Therefore, unless something is passed to gs.fit (e.g. gs.fit(x, y, some_arg=123)), it seems to me that you are missing one of the positional arguments in the definition of MyClassifier.fit. The error message you included supports this hypothesis, since it states fit() takes 2 positional arguments but 3 were given. If you define fit as follows, it takes 3 positional arguments (self, X and y):

def fit(self, X, y): ...
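Python counts self among the positional arguments. That is why the message says "3 were given" for a call that only passes X_train and y_train. The sketch below uses three hypothetical classes. Both a fit missing y and a fit missing self produce exactly the reported message, while fit(self, X, y) accepts the call. inspect.signature is used only to count the declared positional parameters.

```python
import inspect


class MissingY:
    def fit(self, X):      # self + X: 2 positional parameters
        return self


class MissingSelf:
    def fit(X, y):         # forgot self: still 2 positional parameters
        return X


class Correct:
    def fit(self, X, y):   # self + X + y: 3 positional parameters
        return self


def positional_count(cls):
    """Count the positional parameters declared by cls.fit, including self."""
    kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    params = inspect.signature(cls.fit).parameters.values()
    return sum(p.kind in kinds for p in params)


results = {}
for cls in (MissingY, MissingSelf, Correct):
    try:
        cls().fit([[0]], [0])  # the bound call passes self, X_train, y_train
        outcome = "ok"
    except TypeError as exc:
        outcome = str(exc)
    results[cls.__name__] = (positional_count(cls), outcome)
    print(cls.__name__, *results[cls.__name__])
```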
