Sklearn StackingClassifier: Adding features as inputs to the final estimator


Question

I am using pipelines and stacking classifiers to construct a classification pipeline. In my setup, I would like to pass some extra raw features to the final estimator, along with the predictions of the previous level's models. Diagrammatically, this would look like the following:

I would still like to leverage both pipelines (which I've used to set everything up except the adding of Feat x/y) and StackingClassifier to do this, since it handles training a stacked model end-to-end quite cleanly. However, I don't see an option to add raw features to the predictions of the previous level's models. Is there a good way to do this?

Note: The features input to the final estimator are not the same as the features input to model 1 and model 2, so I'm not looking for the passthrough=True flag.

Recommended answer

There is no built-in option for this, but I can think of two ways to piece it together while still making use of the StackingClassifier automation. Each one comes with some downsides.

Make a dummy predictor that "predicts" by simply returning its input, and use it as a base estimator to carry the extra features through to the meta-estimator. Use ColumnTransformer to select either the base estimators' features or the passthrough features.

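import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import StackingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

class IdentityPassthrough(ClassifierMixin, BaseEstimator):
    """Dummy classifier whose "prediction" is its own input."""
    def fit(self, X, y):
        # Record a fitted attribute so scikit-learn treats the estimator as fitted.
        self.classes_ = np.unique(y)
        return self
    def predict(self, X):
        return X

# Extra raw features ('x', 'y') that go straight to the final estimator.
partial_passthrough = Pipeline([
    ('pass', ColumnTransformer([('pass', 'passthrough', ['x', 'y'])])),
    ('ident', IdentityPassthrough()),
])
# Features ('a', 'b') seen by the real base estimators.
base_features = ColumnTransformer([('pass', 'passthrough', ['a', 'b'])])

model = StackingClassifier(estimators=[
        ('pass', partial_passthrough),
        ('tree', Pipeline([('select', base_features), ('tree', DecisionTreeClassifier())])),
        ('knn', Pipeline([('select', base_features), ('knn', KNeighborsClassifier())])),
    ])

# X is a DataFrame with columns 'a', 'b', 'x', 'y'; y holds the class labels.
model.fit(X, y)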
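
To check what the meta-estimator actually receives with this first approach, here is a minimal end-to-end sketch on the Wisconsin breast cancer data. The choice of 'mean perimeter' and 'mean area' as the extra features, the `select` helper, and the `max_iter` setting are assumptions made for illustration; the dummy estimator also derives from `BaseEstimator` and sets `classes_`, which may be needed on recent scikit-learn versions.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier


class IdentityPassthrough(ClassifierMixin, BaseEstimator):
    """Dummy classifier whose "prediction" is its own input."""

    def fit(self, X, y):
        # A fitted attribute so scikit-learn treats the estimator as fitted.
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        return np.asarray(X)


def select(columns):
    # Hypothetical helper: keep only the named columns.
    return ColumnTransformer([("pass", "passthrough", columns)])


X, y = load_breast_cancer(return_X_y=True, as_frame=True)
extra = ["mean perimeter", "mean area"]   # raw features for the final estimator
base = ["mean radius", "mean texture"]    # features for the base estimators

model = StackingClassifier(
    estimators=[
        ("pass", Pipeline([("select", select(extra)),
                           ("ident", IdentityPassthrough())])),
        ("tree", Pipeline([("select", select(base)),
                           ("tree", DecisionTreeClassifier(random_state=42))])),
        ("knn", Pipeline([("select", select(base)),
                          ("knn", KNeighborsClassifier())])),
    ],
    final_estimator=LogisticRegression(max_iter=5000),
)
model.fit(X, y)

# transform() returns the matrix the final estimator is trained on:
# the two raw features first, then one probability column per real model.
meta = model.transform(X)
print(meta.shape)
```

Because the dummy estimator is listed first, its two raw columns lead the meta-feature matrix; reordering the `estimators` list changes that layout accordingly.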

Use passthrough and select out base features

Use a composite for the meta-estimator that narrows the features down to the predictions from the base estimators plus the desired extra features. This is a little worrisome, because you have to know you're getting the right order of columns (until sklearn finishes dealing with feature names). That is, in the code below, features 0 and 1 are the predicted probabilities, one per base estimator (for a binary problem the stacker keeps only the positive-class column of predict_proba; if the negative-class columns were kept too, these would need to be 1 and 3), and 4 and 5 are the targeted passthrough variables (which were indexed 2 and 3 in the original frame).

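from sklearn.compose import ColumnTransformer
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

# The base estimators only see these two columns.
base_features = ColumnTransformer([('pass', 'passthrough', ['mean radius', 'mean texture'])])

model = StackingClassifier(
    estimators=[
        ('tree', Pipeline([('select', base_features), ('tree', DecisionTreeClassifier(random_state=42))])),
        ('knn', Pipeline([('select', base_features), ('knn', KNeighborsClassifier())])),
    ],
    final_estimator=Pipeline([
        # Columns 0-1: base-estimator predictions; 4-5: passed-through raw features.
        ('select', ColumnTransformer([('select', 'passthrough', [0, 1, 4, 5])])),
        ('model', LogisticRegression())
    ]),
    passthrough=True,  # append all original columns after the predictions
)

# X is the breast cancer DataFrame; y holds the class labels.
model.fit(X, y)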
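
One way to reduce the column-order risk in this second approach is to compute the indices instead of hard-coding them, then confirm them against the fitted model. The sketch below assumes a binary problem (one prediction column per base estimator) and a pandas DataFrame input; the names `keep`, `n_pred`, and `select`, and the `max_iter` value, are illustrative choices rather than part of the original answer.

```python
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
base = ["mean radius", "mean texture"]    # features for the base estimators
extra = ["mean perimeter", "mean area"]   # raw features for the final estimator


def select(columns):
    # Hypothetical helper: keep only the given columns.
    return ColumnTransformer([("pass", "passthrough", columns)])


estimators = [
    ("tree", Pipeline([("select", select(base)),
                       ("tree", DecisionTreeClassifier(random_state=42))])),
    ("knn", Pipeline([("select", select(base)),
                      ("knn", KNeighborsClassifier())])),
]

# Assumed layout for a binary target with passthrough=True:
# one prediction column per base estimator, followed by every column of X.
n_pred = len(estimators)
keep = list(range(n_pred)) + [n_pred + int(X.columns.get_loc(c)) for c in extra]

model = StackingClassifier(
    estimators=estimators,
    final_estimator=Pipeline([
        ("select", select(keep)),
        ("model", LogisticRegression(max_iter=5000)),
    ]),
    passthrough=True,
)
model.fit(X, y)

# Sanity check: the selected passthrough columns really are the raw features.
meta = model.transform(X)
print(keep)  # [0, 1, 4, 5]
print(np.allclose(meta[:, keep[n_pred:]], X[extra].to_numpy()))
```

For a multiclass target each base estimator would contribute several probability columns, so `n_pred` would have to be derived differently; checking `model.transform(X)` as above is a cheap guard in either case.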

Model diagrams:

<!-- style defs (common to the two exports from estimator_html_repr) -->
<style>div.sk-top-container {color: black;background-color: white;}div.sk-toggleable {background-color: white;}label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.2em 0.3em;box-sizing: border-box;text-align: center;}div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}div.sk-estimator {font-family: monospace;background-color: #f0f8ff;margin: 0.25em 0.25em;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;}div.sk-estimator:hover {background-color: #d4ebff;}div.sk-parallel-item::after {content: "";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}div.sk-serial::before {content: "";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 2em;bottom: 0;left: 50%;}div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;}div.sk-item {z-index: 1;}div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;}div.sk-parallel-item {display: flex;flex-direction: column;position: relative;background-color: white;}div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}div.sk-parallel-item:last-child::after {align-self: flex-start;width: 
50%;}div.sk-parallel-item:only-child::after {width: 0;}div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0.2em;box-sizing: border-box;padding-bottom: 0.1em;background-color: white;position: relative;}div.sk-label label {font-family: monospace;font-weight: bold;background-color: white;display: inline-block;line-height: 1.2em;}div.sk-label-container {position: relative;z-index: 2;text-align: center;}div.sk-container {display: inline-block;position: relative;}</style>

<!-- First approach diagram: -->
<div class="sk-top-container"><div class="sk-container"><div class="sk-item sk-dashed-wrapped"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="a4c7712b-4e69-42ca-b31f-ecbe7b6d1d89" type="checkbox" ><label class="sk-toggleable__label" for="a4c7712b-4e69-42ca-b31f-ecbe7b6d1d89">StackingClassifier</label><div class="sk-toggleable__content"><pre>StackingClassifier(estimators=[('pass',                 Pipeline(steps=[('pass',                         ColumnTransformer(transformers=[('pass',                                          'passthrough',                                          ['mean '                                          'perimeter',                                          'mean '                                          'area'])])),                         ('ident',                         <__main__.IdentityPassthrough object at 0x7f2bfbf1f358>)])),                ('tree',                 Pipeline(steps=[('select',                         ColumnTransformer(transformers=[('pass',                                          'passthrough',                                          ['mean '                                          'radius',                                          'mean '                                          'texture'])])),                         ('tree',                         DecisionTreeClassifier(random_state=42))])),                ('knn',                 Pipeline(steps=[('select',                         ColumnTransformer(transformers=[('pass',                                          'passthrough',                                          ['mean '                                          'radius',                                          'mean '                                          'texture'])])),                         ('knn',                         KNeighborsClassifier())]))])</pre></div></div></div><div class="sk-serial"><div 
class="sk-item"><div class="sk-serial"><div class="sk-item"><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label sk-toggleable"><label>pass</label></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-serial"><div class="sk-item sk-dashed-wrapped"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="c41ba408-2542-42cf-be5d-d2bdb1f7ca39" type="checkbox" ><label class="sk-toggleable__label" for="c41ba408-2542-42cf-be5d-d2bdb1f7ca39">pass: ColumnTransformer</label><div class="sk-toggleable__content"><pre>ColumnTransformer(transformers=[('pass', 'passthrough',                 ['mean perimeter', 'mean area'])])</pre></div></div></div><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="d71f311c-d151-402b-80a4-97fcb9464d8f" type="checkbox" ><label class="sk-toggleable__label" for="d71f311c-d151-402b-80a4-97fcb9464d8f">pass</label><div class="sk-toggleable__content"><pre>['mean perimeter', 'mean area']</pre></div></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="3e9e3a12-5622-4e56-9171-dbc690ca50d8" type="checkbox" ><label class="sk-toggleable__label" for="3e9e3a12-5622-4e56-9171-dbc690ca50d8">passthrough</label><div class="sk-toggleable__content"><pre>passthrough</pre></div></div></div></div></div></div></div></div><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="b5753df0-b293-4aeb-bbdb-d9adc63b6ac8" type="checkbox" ><label class="sk-toggleable__label" for="b5753df0-b293-4aeb-bbdb-d9adc63b6ac8">IdentityPassthrough</label><div 
class="sk-toggleable__content"><pre><__main__.IdentityPassthrough object at 0x7f2bfbf1f358></pre></div></div></div></div></div></div></div></div><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label sk-toggleable"><label>tree</label></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-serial"><div class="sk-item sk-dashed-wrapped"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="4952e340-a144-40cd-897b-dcdee029fecb" type="checkbox" ><label class="sk-toggleable__label" for="4952e340-a144-40cd-897b-dcdee029fecb">select: ColumnTransformer</label><div class="sk-toggleable__content"><pre>ColumnTransformer(transformers=[('pass', 'passthrough',                 ['mean radius', 'mean texture'])])</pre></div></div></div><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="7efbe86f-2262-4048-81fd-7c652803cf4f" type="checkbox" ><label class="sk-toggleable__label" for="7efbe86f-2262-4048-81fd-7c652803cf4f">pass</label><div class="sk-toggleable__content"><pre>['mean radius', 'mean texture']</pre></div></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="7ccdbf06-9312-4424-a74e-e1c56b3fbe88" type="checkbox" ><label class="sk-toggleable__label" for="7ccdbf06-9312-4424-a74e-e1c56b3fbe88">passthrough</label><div class="sk-toggleable__content"><pre>passthrough</pre></div></div></div></div></div></div></div></div><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="3a2df016-3631-4cc6-960c-695466268875" type="checkbox" ><label class="sk-toggleable__label" 
for="3a2df016-3631-4cc6-960c-695466268875">DecisionTreeClassifier</label><div class="sk-toggleable__content"><pre>DecisionTreeClassifier(random_state=42)</pre></div></div></div></div></div></div></div></div><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label sk-toggleable"><label>knn</label></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-serial"><div class="sk-item sk-dashed-wrapped"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="5aaa1a50-3954-43c0-802d-0679ecfaaa5f" type="checkbox" ><label class="sk-toggleable__label" for="5aaa1a50-3954-43c0-802d-0679ecfaaa5f">select: ColumnTransformer</label><div class="sk-toggleable__content"><pre>ColumnTransformer(transformers=[('pass', 'passthrough',                 ['mean radius', 'mean texture'])])</pre></div></div></div><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="836d9068-cfb3-4545-a714-6f349403d567" type="checkbox" ><label class="sk-toggleable__label" for="836d9068-cfb3-4545-a714-6f349403d567">pass</label><div class="sk-toggleable__content"><pre>['mean radius', 'mean texture']</pre></div></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="0c9e6b1f-0f96-4d6f-8efc-e3bada46d6a7" type="checkbox" ><label class="sk-toggleable__label" for="0c9e6b1f-0f96-4d6f-8efc-e3bada46d6a7">passthrough</label><div class="sk-toggleable__content"><pre>passthrough</pre></div></div></div></div></div></div></div></div><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="56bfccc3-bca0-4d87-a377-a81913e4098c" type="checkbox" ><label class="sk-toggleable__label" 
for="56bfccc3-bca0-4d87-a377-a81913e4098c">KNeighborsClassifier</label><div class="sk-toggleable__content"><pre>KNeighborsClassifier()</pre></div></div></div></div></div></div></div></div></div></div><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="57737187-5f4c-4186-ad65-e68cecfe14e8" type="checkbox" ><label class="sk-toggleable__label" for="57737187-5f4c-4186-ad65-e68cecfe14e8">LogisticRegression</label><div class="sk-toggleable__content"><pre>LogisticRegression()</pre></div></div></div></div></div></div></div></div></div>

<!-- Second approach diagram: -->
<div class="sk-top-container"><div class="sk-container"><div class="sk-item sk-dashed-wrapped"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="034d2534-0d63-4319-bfbf-3b0a7117e00f" type="checkbox" ><label class="sk-toggleable__label" for="034d2534-0d63-4319-bfbf-3b0a7117e00f">StackingClassifier</label><div class="sk-toggleable__content"><pre>StackingClassifier(estimators=[('tree',                 Pipeline(steps=[('select',                         ColumnTransformer(transformers=[('pass',                                          'passthrough',                                          ['mean '                                          'radius',                                          'mean '                                          'texture'])])),                         ('tree',                         DecisionTreeClassifier(random_state=42))])),                ('knn',                 Pipeline(steps=[('select',                         ColumnTransformer(transformers=[('pass',                                          'passthrough',                                          ['mean '                                          'radius',                                          'mean '                                          'texture'])])),                         ('knn',                         KNeighborsClassifier())]))],          final_estimator=Pipeline(steps=[('select',                           ColumnTransformer(transformers=[('select',                                           'passthrough',                                           [0,                                            1,                                            4,                                            5])])),                          ('model',                           LogisticRegression())]),          passthrough=True)</pre></div></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-serial"><div 
class="sk-item"><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label sk-toggleable"><label>tree</label></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-serial"><div class="sk-item sk-dashed-wrapped"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="50985202-3021-4333-877c-034e62c6e07a" type="checkbox" ><label class="sk-toggleable__label" for="50985202-3021-4333-877c-034e62c6e07a">select: ColumnTransformer</label><div class="sk-toggleable__content"><pre>ColumnTransformer(transformers=[('pass', 'passthrough',                 ['mean radius', 'mean texture'])])</pre></div></div></div><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="8132ce46-3e0b-42d2-b42b-f0a53d192c07" type="checkbox" ><label class="sk-toggleable__label" for="8132ce46-3e0b-42d2-b42b-f0a53d192c07">pass</label><div class="sk-toggleable__content"><pre>['mean radius', 'mean texture']</pre></div></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="e1970d86-b28e-41d0-8297-4d2ed67b4b50" type="checkbox" ><label class="sk-toggleable__label" for="e1970d86-b28e-41d0-8297-4d2ed67b4b50">passthrough</label><div class="sk-toggleable__content"><pre>passthrough</pre></div></div></div></div></div></div></div></div><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="fafe6dec-d6a7-4c00-b561-17f3307e4bde" type="checkbox" ><label class="sk-toggleable__label" for="fafe6dec-d6a7-4c00-b561-17f3307e4bde">DecisionTreeClassifier</label><div 
class="sk-toggleable__content"><pre>DecisionTreeClassifier(random_state=42)</pre></div></div></div></div></div></div></div></div><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label sk-toggleable"><label>knn</label></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-serial"><div class="sk-item sk-dashed-wrapped"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="8d608e7c-c318-4a67-a9b7-26995a77bcc6" type="checkbox" ><label class="sk-toggleable__label" for="8d608e7c-c318-4a67-a9b7-26995a77bcc6">select: ColumnTransformer</label><div class="sk-toggleable__content"><pre>ColumnTransformer(transformers=[('pass', 'passthrough',                 ['mean radius', 'mean texture'])])</pre></div></div></div><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="12515389-377c-4fd1-9b50-cf0515dc1919" type="checkbox" ><label class="sk-toggleable__label" for="12515389-377c-4fd1-9b50-cf0515dc1919">pass</label><div class="sk-toggleable__content"><pre>['mean radius', 'mean texture']</pre></div></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="4fa202f1-74c0-47ba-b34d-dbc4e460eff9" type="checkbox" ><label class="sk-toggleable__label" for="4fa202f1-74c0-47ba-b34d-dbc4e460eff9">passthrough</label><div class="sk-toggleable__content"><pre>passthrough</pre></div></div></div></div></div></div></div></div><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="0ac195fa-6584-4220-a00e-8da1dd09b5de" type="checkbox" ><label class="sk-toggleable__label" for="0ac195fa-6584-4220-a00e-8da1dd09b5de">KNeighborsClassifier</label><div 
class="sk-toggleable__content"><pre>KNeighborsClassifier()</pre></div></div></div></div></div></div></div></div></div></div><div class="sk-item"><div class="sk-serial"><div class="sk-item sk-dashed-wrapped"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="ede1e052-c7af-4bd3-9da9-2fcc69dd8c86" type="checkbox" ><label class="sk-toggleable__label" for="ede1e052-c7af-4bd3-9da9-2fcc69dd8c86">select: ColumnTransformer</label><div class="sk-toggleable__content"><pre>ColumnTransformer(transformers=[('select', 'passthrough', [0, 1, 4, 5])])</pre></div></div></div><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="8ad23c33-3176-45c6-9504-98299d187eda" type="checkbox" ><label class="sk-toggleable__label" for="8ad23c33-3176-45c6-9504-98299d187eda">select</label><div class="sk-toggleable__content"><pre>[0, 1, 4, 5]</pre></div></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="9fa0446e-ea2b-4c32-b32b-07c9b2643717" type="checkbox" ><label class="sk-toggleable__label" for="9fa0446e-ea2b-4c32-b32b-07c9b2643717">passthrough</label><div class="sk-toggleable__content"><pre>passthrough</pre></div></div></div></div></div></div></div></div><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="450a759e-b194-4bf9-a92b-b296f6c9f527" type="checkbox" ><label class="sk-toggleable__label" for="450a759e-b194-4bf9-a92b-b296f6c9f527">LogisticRegression</label><div class="sk-toggleable__content"><pre>LogisticRegression()</pre></div></div></div></div></div></div></div></div></div></div></div>

See the whole thing in action on the Wisconsin breast cancer dataset in this notebook.
