VotingClassifier发生TypeError [英] An Typeerror with VotingClassifier

查看:132
本文介绍了VotingClassifier发生TypeError的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想使用VotingClassifier,但是交叉验证存在一些问题

I want to use VotingClassifier, but I have some problems with cross validating

    x_train, x_validation, y_train, y_validation = train_test_split(x, y, test_size=.22, random_state=2)
    x_train = x_train.fillna(0)
    clf1 = CatBoostClassifier()
    clf2 = RandomForestClassifier()
    clf = VotingClassifier(estimators=[('cb', clf1), ('rf', clf2)])
    clf.fit(x_train.values(), y_train)

soooo,我在预测...时出错.

soooo, I have an error with predicting...

    cross_validate(clf, x_train, y_train, scoring='accuracy', return_train_score = True, n_jobs = 4)


TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

(所有错误此处)

并在此处下载x_train和y_train↓

and download x_train and y_train here ↓

x_train
y_train

推荐答案

此错误是由于以下原因引起的:

This error is because of this line:

np.bincount(x, weights=self._weights_not_none)

此处x是VotingClassifier中各个分类器返回的预测.

Here x is the predictions returned by the individual classifiers inside the VotingClassifier.

根据 np.bincount :

计算非负数数组中每个值的出现次数 整数.

Count number of occurrences of each value in array of non-negative ints.

x:类似array_,一维,非负整数

x : array_like, 1 dimension, nonnegative ints

此方法仅需要数组中的int值.

This method requires only int values in the array.

现在,如果将CatBoostClassifier替换为任何其他Scikit学习分类器,则您的代码将可用.因为所有scikit-learn估计量都从其predict()返回np.int64的数组.

Now your code will work if you replace the CatBoostClassifier with any other Scikit-learn classifier. Because all scikit-learn estimators return array of np.int64 from their predict().

但是CatBoostClassifier返回np.float64作为输出.因此,错误.实际上,它还应该返回int64,因为predict()函数应该返回类,而不返回任何浮点值.但是我不知道为什么它会返回浮点数.

But CatBoostClassifier returns np.float64 as the output. And hence the error. Actually it should also return int64 because the predict() function should return the classes not any float values. But I dont know why it returns float.

您可以通过扩展CatBoostClassifier类并即时转换预测来纠正此问题.

You can correct this by extending the CatBoostClassifier class and converting the predictions on the fly.

import numpy as np
from catboost import CatBoostClassifier
class CatBoostClassifierInt(CatBoostClassifier):
    def predict(self, data, prediction_type='Class', ntree_start=0, ntree_end=0, thread_count=1, verbose=None):
        predictions = self._predict(data, prediction_type, ntree_start, ntree_end, thread_count, verbose)

        # This line is the only change I did
        return np.asarray(predictions, dtype=np.int64).ravel()

clf1 = CatBoostClassifierInt()
clf2 = RandomForestClassifier()
clf = VotingClassifier(estimators=[('cb', clf1), ('rf', clf2)])
cross_validate(clf, x_train, y_train, scoring='accuracy', return_train_score = True)

现在您不会收到该错误.

Now you wont get that error.

更正确的版本应该是这个.这将处理具有匹配输入和输出的所有类型的标签,并且可以轻松地在scikit中使用:

More correct version should be this. This will handle all the types of labels with matching input and output and can be used in scikit with ease:

class CatBoostClassifierCorrected(CatBoostClassifier):
    def fit(self, X, y=None, cat_features=None, sample_weight=None, baseline=None, use_best_model=None,
        eval_set=None, verbose=None, logging_level=None, plot=False, column_description=None, verbose_eval=None):

        self.le_ = LabelEncoder().fit(y)
        transformed_y = self.le_.transform(y)

        self._fit(X, transformed_y, cat_features, None, sample_weight, None, None, None, baseline, use_best_model, eval_set, verbose, logging_level, plot, column_description, verbose_eval)
        return self

    def predict(self, data, prediction_type='Class', ntree_start=0, ntree_end=0, thread_count=1, verbose=None):
        predictions = self._predict(data, prediction_type, ntree_start, ntree_end, thread_count, verbose)

        # This line is the only change I did
        return self.le_.inverse_transform(predictions.astype(np.int64))

这将处理所有不同类型的标签

This will handle all different types of labels

这篇关于VotingClassifier发生TypeError的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆