VotingClassifier 的类型错误 [英] Typeerror with VotingClassifier

查看:28
本文介绍了VotingClassifier 的类型错误的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想使用 VotingClassifier,但我在交叉验证方面遇到了一些问题

I want to use VotingClassifier, but I have some problems with cross validating

    x_train, x_validation, y_train, y_validation = train_test_split(x, y, test_size=.22, random_state=2)
    x_train = x_train.fillna(0)
    clf1 = CatBoostClassifier()
    clf2 = RandomForestClassifier()
    clf = VotingClassifier(estimators=[('cb', clf1), ('rf', clf2)])
    clf.fit(x_train.values(), y_train)

我在预测时出错...

    cross_validate(clf, x_train, y_train, scoring='accuracy', return_train_score = True, n_jobs = 4)


TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

(完整错误此处)

并在此处下载 x_train 和 y_train ↓

and download x_train and y_train here ↓

x_train
y_train

推荐答案

这个错误是因为这一行:

This error is because of this line:

np.bincount(x, weights=self._weights_not_none)

这里的 x 是 VotingClassifier 中各个分类器返回的预测.

Here x is the predictions returned by the individual classifiers inside the VotingClassifier.

根据np.bincount的文档:

计算每个值在非负数组中出现的次数整数.

Count number of occurrences of each value in array of non-negative ints.

x : array_like,一维,非负整数

x : array_like, 1 dimension, nonnegative ints

此方法只需要数组中的 int 值.

This method requires only int values in the array.

现在,如果您将 CatBoostClassifier 替换为任何其他 Scikit-learn 分类器,您的代码将起作用.因为所有 scikit-learn 估计器都从它们的 predict() 返回 np.int64 数组.

Now your code will work if you replace the CatBoostClassifier with any other Scikit-learn classifier. Because all scikit-learn estimators return array of np.int64 from their predict().

但是 CatBoostClassifier 返回 np.float64 作为输出.因此错误.实际上它也应该返回 int64,因为 predict() 函数应该返回类而不是任何浮点值.但我不知道为什么它返回浮动.

But CatBoostClassifier returns np.float64 as the output. And hence the error. Actually it should also return int64 because the predict() function should return the classes not any float values. But I dont know why it returns float.

您可以通过扩展 CatBoostClassifier 类并即时转换预测来纠正此问题.

You can correct this by extending the CatBoostClassifier class and converting the predictions on the fly.

import numpy as np
from catboost import CatBoostClassifier
class CatBoostClassifierInt(CatBoostClassifier):
    def predict(self, data, prediction_type='Class', ntree_start=0, ntree_end=0, thread_count=1, verbose=None):
        predictions = self._predict(data, prediction_type, ntree_start, ntree_end, thread_count, verbose)

        # This line is the only change I did
        return np.asarray(predictions, dtype=np.int64).ravel()

clf1 = CatBoostClassifierInt()
clf2 = RandomForestClassifier()
clf = VotingClassifier(estimators=[('cb', clf1), ('rf', clf2)])
cross_validate(clf, x_train, y_train, scoring='accuracy', return_train_score = True)

现在你不会得到那个错误了.

Now you wont get that error.

更正确的版本应该是这个.这将处理具有匹配输入和输出的所有类型的标签,并且可以轻松地在 scikit 中使用:

More correct version should be this. This will handle all the types of labels with matching input and output and can be used in scikit with ease:

class CatBoostClassifierCorrected(CatBoostClassifier):
    def fit(self, X, y=None, cat_features=None, sample_weight=None, baseline=None, use_best_model=None,
        eval_set=None, verbose=None, logging_level=None, plot=False, column_description=None, verbose_eval=None):

        self.le_ = LabelEncoder().fit(y)
        transformed_y = self.le_.transform(y)

        self._fit(X, transformed_y, cat_features, None, sample_weight, None, None, None, baseline, use_best_model, eval_set, verbose, logging_level, plot, column_description, verbose_eval)
        return self

    def predict(self, data, prediction_type='Class', ntree_start=0, ntree_end=0, thread_count=1, verbose=None):
        predictions = self._predict(data, prediction_type, ntree_start, ntree_end, thread_count, verbose)

        # This line is the only change I did
        return self.le_.inverse_transform(predictions.astype(np.int64))

这将处理所有不同类型的标签

This will handle all different types of labels

这篇关于VotingClassifier 的类型错误的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆