Inherit from scikit-learn's LassoCV model

Problem description

I tried to extend scikit-learn's RidgeCV model using inheritance:

from sklearn.linear_model import RidgeCV, LassoCV

class Extended(RidgeCV):
    def __init__(self, *args, **kwargs):
        super(Extended, self).__init__(*args, **kwargs)

    def example(self):
        print 'Foo'


x = [[1,0],[2,0],[3,0],[4,0], [30, 1]]
y = [2,4,6,8, 60]
model = Extended(alphas = [float(a)/1000.0 for a in range(1, 10000)])
model.fit(x,y)
print model.predict([[5,1]])

It worked perfectly fine, but when I tried to inherit from LassoCV instead, it produced the following traceback:

Traceback (most recent call last):
  File "C:/Python27/so.py", line 14, in <module>
    model.fit(x,y)
  File "C:\Python27\lib\site-packages\sklearn\linear_model\coordinate_descent.py", line 1098, in fit
    path_params = self.get_params()
  File "C:\Python27\lib\site-packages\sklearn\base.py", line 214, in get_params
    for key in self._get_param_names():
  File "C:\Python27\lib\site-packages\sklearn\base.py", line 195, in _get_param_names
    % (cls, init_signature))
RuntimeError: scikit-learn estimators should always specify their parameters in the signature of their __init__ (no varargs). <class '__main__.Extended'> with constructor (<self>, *args, **kwargs) doesn't  follow this convention.

Can somebody explain how to fix this?

Recommended answer

You probably want your model to be compatible with scikit-learn so that you can use it with the rest of scikit-learn's functionality. If so, read this first: http://scikit-learn.org/stable/developers/contributing.html#rolling-your-own-estimator

In short: scikit-learn has many features, such as estimator cloning (the clone() function) and meta-algorithms such as GridSearch, Pipeline, and cross-validation. All of them must be able to read the values of the fields inside your estimator and change them, such as the alpha parameter of SGDClassifier (for example, GridSearch has to change your estimator's parameters before each evaluation). To change the value of a parameter, it has to know the parameter's name. To get the names of all fields of any estimator, the get_params method of the BaseEstimator class (which you inherit from indirectly) requires every parameter to be specified in the class's __init__ method, because the parameter names of __init__ are easy to introspect (look at BaseEstimator; this is the class that raises the error).
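To make the introspection idea concrete, here is a minimal sketch that uses only the standard library. It is a simplified stand-in, not scikit-learn's actual code: the helper init_param_names and the classes Explicit and Variadic are hypothetical names invented for this illustration, and the sketch assumes that any variadic parameter should be rejected.

```python
import inspect


def init_param_names(cls):
    """Return the sorted names of the explicit __init__ parameters of cls.

    Hypothetical helper: raises RuntimeError if __init__ uses *args or
    **kwargs, because such parameters have no discoverable names.
    """
    signature = inspect.signature(cls.__init__)
    names = []
    for param in signature.parameters.values():
        if param.name == "self":
            continue
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            raise RuntimeError(
                "%s.__init__ uses varargs; parameter names cannot be listed"
                % cls.__name__
            )
        names.append(param.name)
    return sorted(names)


class Explicit:
    # Every parameter is named, so its name can be discovered.
    def __init__(self, alpha=1.0, max_iter=100):
        self.alpha = alpha
        self.max_iter = max_iter


class Variadic:
    # Nothing in this signature says which parameters exist.
    def __init__(self, *args, **kwargs):
        pass


print(init_param_names(Explicit))

try:
    init_param_names(Variadic)
except RuntimeError as error:
    print("Rejected:", error)
```

With an explicit signature the names can be read directly from __init__; with *args, **kwargs there is nothing to read, which is the situation the error message in the question describes.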

So it simply wants you to remove all varargs, such as

*args, **kwargs

from the __init__ signature. You have to list all of your model's parameters in the __init__ signature and initialise all of the object's internal fields.

Here is an example: the __init__ method of SGDClassifier, which inherits from BaseSGDClassifier:

def __init__(self, loss="hinge", penalty='l2', alpha=0.0001, l1_ratio=0.15,
             fit_intercept=True, n_iter=5, shuffle=True, verbose=0,
             epsilon=DEFAULT_EPSILON, n_jobs=1, random_state=None,
             learning_rate="optimal", eta0=0.0, power_t=0.5,
             class_weight=None, warm_start=False, average=False):
    super(SGDClassifier, self).__init__(
        loss=loss, penalty=penalty, alpha=alpha, l1_ratio=l1_ratio,
        fit_intercept=fit_intercept, n_iter=n_iter, shuffle=shuffle,
        verbose=verbose, epsilon=epsilon, n_jobs=n_jobs,
        random_state=random_state, learning_rate=learning_rate, eta0=eta0,
        power_t=power_t, class_weight=class_weight, warm_start=warm_start, average=average)
