Override a method for a collection of classes implementing an interface
Question
I am using scikit-learn and am building a pipeline. Once the pipeline is built, I use GridSearchCV to find the optimal model. I am working with text data, so I am experimenting with different stemmers. I have created a class called Preprocessor that takes a stemmer class and a vectorizer class, then attempts to override the vectorizer's build_analyzer method to incorporate the given stemmer. However, I see that GridSearchCV's set_params just sets instance variables directly, i.e. it will not re-instantiate a vectorizer with a new analyzer, which is what my code relies on:
    class Preprocessor(object):
        # hard code the stopwords for now
        stopwords = nltk.corpus.stopwords.words()

        def __init__(self, stemmer_cls, vectorizer_cls):
            self.stemmer = stemmer_cls()
            analyzer = self._build_analyzer(self.stemmer, vectorizer_cls)
            self.vectorizer = vectorizer_cls(stopwords=stopwords,
                                             analyzer=analyzer,
                                             decode_error='ignore')

        def _build_analyzer(self, stemmer, vectorizer_cls):
            # analyzer tokenizes and lowercases
            analyzer = super(vectorizer_cls, self).build_analyzer()
            return lambda doc: (stemmer.stem(w) for w in analyzer(doc))

        def fit(self, **kwargs):
            return self.vectorizer.fit(kwargs)

        def transform(self, **kwargs):
            return self.vectorizer.transform(kwargs)

        def fit_transform(self, **kwargs):
            return self.vectorizer.fit_transform(kwargs)
So the question is: how can I override build_analyzer for every vectorizer class that is passed in?
Recommended answer
Yes, GridSearchCV sets the instance fields directly and then calls fit on the estimator with the changed fields.
Every estimator in scikit-learn is built so that __init__ only stores its parameters as fields; all dependent objects needed for further work (such as the result of calling _build_analyzer in your case) are constructed inside the fit method. You therefore have to add a field that stores vectorizer_cls, and then construct the objects that depend on vectorizer_cls and stemmer_cls inside fit.
Something like this:
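    import nltk
    from sklearn.base import BaseEstimator, TransformerMixin

    class Preprocessor(BaseEstimator, TransformerMixin):
        # hard code the stopwords for now (requires the NLTK stopwords corpus)
        stopwords = nltk.corpus.stopwords.words()

        def __init__(self, stemmer_cls, vectorizer_cls):
            # only store the parameters; nothing is built here
            self.stemmer_cls = stemmer_cls
            self.vectorizer_cls = vectorizer_cls

        def _build_analyzer(self, stemmer):
            # the stock analyzer lowercases, tokenizes and drops stop words;
            # Preprocessor is not a vectorizer subclass, so the analyzer is
            # taken from a throwaway vectorizer instance rather than super()
            base = self.vectorizer_cls(stop_words=self.stopwords,
                                       decode_error='ignore').build_analyzer()
            return lambda doc: [stemmer.stem(w) for w in base(doc)]

        def fit(self, X, y=None):
            # dependent objects are built from the current parameter values
            analyzer = self._build_analyzer(self.stemmer_cls())
            # keep the fitted vectorizer separate from the vectorizer_cls parameter
            self.vectorizer_ = self.vectorizer_cls(analyzer=analyzer,
                                                   decode_error='ignore')
            self.vectorizer_.fit(X)
            return self

        def transform(self, X):
            return self.vectorizer_.transform(X)

        def fit_transform(self, X, y=None):
            return self.fit(X).transform(X)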
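To see why construction has to move into fit, here is a minimal sketch that uses a toy stand-in for the parameter convention described above. It is not scikit-learn's own code: ToyEstimator, EagerPreprocessor, LazyPreprocessor, IdentityStemmer and SuffixStemmer are all hypothetical names invented for this illustration, and the only assumption carried over is that set_params assigns attributes without calling __init__ again.

```python
class ToyEstimator:
    """Toy imitation of set_params (assumption: it only assigns attributes)."""

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)  # __init__ is not called again
        return self


class EagerPreprocessor(ToyEstimator):
    """Builds its stemmer in __init__, as in the question."""

    def __init__(self, stemmer_cls):
        self.stemmer_cls = stemmer_cls
        self.stemmer = stemmer_cls()  # frozen at construction time

    def fit(self, docs):
        words = (w for doc in docs for w in doc.split())
        self.vocabulary_ = sorted({self.stemmer.stem(w) for w in words})
        return self


class LazyPreprocessor(ToyEstimator):
    """Stores parameters in __init__ and builds the stemmer in fit."""

    def __init__(self, stemmer_cls):
        self.stemmer_cls = stemmer_cls  # stored only, nothing is built

    def fit(self, docs):
        stemmer = self.stemmer_cls()  # built from the current parameter value
        words = (w for doc in docs for w in doc.split())
        self.vocabulary_ = sorted({stemmer.stem(w) for w in words})
        return self


class IdentityStemmer:
    """Hypothetical stemmer that leaves words unchanged."""

    def stem(self, word):
        return word


class SuffixStemmer:
    """Hypothetical stemmer that strips one trailing 's'."""

    def stem(self, word):
        return word[:-1] if word.endswith("s") else word


docs = ["cats chase dogs"]
eager = EagerPreprocessor(IdentityStemmer).set_params(stemmer_cls=SuffixStemmer).fit(docs)
lazy = LazyPreprocessor(IdentityStemmer).set_params(stemmer_cls=SuffixStemmer).fit(docs)

print(eager.vocabulary_)  # ['cats', 'chase', 'dogs'] -> the new stemmer was ignored
print(lazy.vocabulary_)   # ['cat', 'chase', 'dog']   -> the new stemmer was used
```

The eager variant keeps using the stemmer it created in __init__, so a parameter change made through set_params has no effect; the lazy variant picks up the new value because it reads the parameter at fit time.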