用于多标签分类的 XGBoost? [英] XGBoost for multilabel classification?
问题描述
是否可以使用 XGBoost 进行多标签分类?现在我使用 OneVsRestClassifier
而不是 sklearn
的 GradientBoostingClassifier
.它可以工作,但只使用我 CPU 的一个内核.在我的数据中,我有大约 45 个特征,任务是用二进制(布尔)数据预测大约 20 列.指标是平均精度 (map@7).如果您有一个简短的代码示例要分享,那就太好了.
Is it possible to use XGBoost for multi-label classification? Now I use OneVsRestClassifier
over GradientBoostingClassifier
from sklearn
. It works, but use only one core from my CPU. In my data I have ~45 features and the task is to predict about 20 columns with binary (boolean) data. Metric is mean average precision (map@7). If you have a short example of code to share, that would be great.
推荐答案
一种可能的方法,而不是使用用于多类任务的 OneVsRestClassifier
,是使用 MultiOutputClassifier
> 来自 sklearn.multioutput
模块.
One possible approach, instead of using OneVsRestClassifier
which is for multi-class tasks, is to use MultiOutputClassifier
from the sklearn.multioutput
module.
下面是一个小的可重现的示例代码,其中包含 OP 请求的输入特征和目标输出的数量
Below is a small reproducible sample code with the number of input features and target outputs requested by the OP
import xgboost as xgb
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import accuracy_score
# create sample dataset
X, y = make_multilabel_classification(n_samples=3000, n_features=45, n_classes=20, n_labels=1,
allow_unlabeled=False, random_state=42)
# split dataset into training and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)
# create XGBoost instance with default hyper-parameters
xgb_estimator = xgb.XGBClassifier(objective='binary:logistic')
# create MultiOutputClassifier instance with XGBoost model inside
multilabel_model = MultiOutputClassifier(xgb_estimator)
# fit the model
multilabel_model.fit(X_train, y_train)
# evaluate on test data
print('Accuracy on test data: {:.1f}%'.format(accuracy_score(y_test, multilabel_model.predict(X_test))*100))
这篇关于用于多标签分类的 XGBoost?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!