Unexpected behavior from xgboost in Python with Custom Evaluation Function

Question

I am using xgboost with a custom evaluation function, and I would like to implement early stopping with a limit of 150 rounds.

I am getting back 4 evaluation metrics instead of the expected 2, and I do not know how to interpret them. I am also not sure how to activate early stopping with a limit (e.g., 150 rounds).

For a reproducible example:

import numpy as np
import xgboost as xgb

def F1_eval_gen(preds, labels):
    t = np.arange(0, 1, 0.005)
    f = np.repeat(0, 200)
    results = np.vstack([t, f]).T
    # assuming labels only containing 0's and 1's
    n_pos_examples = sum(labels)
    if n_pos_examples == 0:
        n_pos_examples = 1

    for i in range(200):
        pred_indexes = (preds >= results[i, 0])
        TP = sum(labels[pred_indexes])
        FP = len(labels[pred_indexes]) - TP
        precision = 0
        recall = TP / n_pos_examples

        if (FP + TP) > 0:
            precision = TP / (FP + TP)

        if (precision + recall > 0):
            F1 = 2 * precision * recall / (precision + recall)
        else:
            F1 = 0
        results[i, 1] = F1
    return (max(results[:, 1]))

def F1_eval(preds, dtrain):
    res = F1_eval_gen(preds, dtrain.get_label())
    return 'f1_err', 1-res

from sklearn import datasets
from sklearn.model_selection import train_test_split

skl_data = datasets.load_breast_cancer()

X = skl_data.data

y = skl_data.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)

scale_pos_weight = sum(y_train == 0)/sum(y_train == 1)


base_score = sum(y_train == 1)/len(y_train)


max_depth = 6
learning_rate = 0.1
gamma = 0
min_child_weight = 1
subsample = 0.8
colsample_bytree = 0.8
colsample_bylevel = 1
reg_alpha = 0
reg_lambda = 1


clf = xgb.XGBClassifier(max_depth=max_depth, learning_rate=learning_rate, silent=False,
                        objective='binary:logistic', booster='gbtree', n_jobs=8, nthread=None,
                        gamma=gamma, min_child_weight=min_child_weight, max_delta_step=0,
                        subsample=subsample, colsample_bytree=colsample_bytree,
                        colsample_bylevel=colsample_bylevel, reg_alpha=reg_alpha,
                        reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
                        base_score=base_score)

clf.fit(X_train, y_train,
        eval_set=[(X_train, y_train), (X_test, y_test)], eval_metric=F1_eval, verbose=True)

..................
[94]    validation_0-error:0    validation_1-error:0.035088 validation_0-f1_err:0   validation_1-f1_err:0.018634
[95]    validation_0-error:0    validation_1-error:0.035088 validation_0-f1_err:0   validation_1-f1_err:0.018634
[96]    validation_0-error:0    validation_1-error:0.035088 validation_0-f1_err:0   validation_1-f1_err:0.018634
[97]    validation_0-error:0    validation_1-error:0.035088 validation_0-f1_err:0   validation_1-f1_err:0.018634
[98]    validation_0-error:0    validation_1-error:0.035088 validation_0-f1_err:0   validation_1-f1_err:0.018634
[99]    validation_0-error:0    validation_1-error:0.035088 validation_0-f1_err:0   validation_1-f1_err:0.018634

clf = xgb.XGBClassifier(max_depth=max_depth, niterations=1000, learning_rate=learning_rate, silent=False,
                        objective='binary:logistic', booster='gbtree', n_jobs=8, nthread=None,
                        gamma=gamma, min_child_weight=min_child_weight, max_delta_step=0,
                        subsample=subsample, colsample_bytree=colsample_bytree,
                        colsample_bylevel=colsample_bylevel, reg_alpha=reg_alpha,
                        reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
                        base_score=base_score)

clf.fit(X_train, y_train, early_stopping_rounds=25,
        eval_set=[(X_train, y_train), (X_test, y_test)], eval_metric=F1_eval, verbose=True)

[0] validation_0-error:0.386813 validation_1-error:0.315789 validation_0-f1_err:0.032609    validation_1-f1_err:0.031847
Multiple eval metrics have been passed: 'validation_1-f1_err' will be used for early stopping.

Will train until validation_1-f1_err hasn't improved in 25 rounds.
[1] validation_0-error:0.131868 validation_1-error:0.078947 validation_0-f1_err:0.016216    validation_1-f1_err:0.031056
[2] validation_0-error:0.048352 validation_1-error:0.052632 validation_0-f1_err:0.012522    validation_1-f1_err:0.037037
[3] validation_0-error:0.032967 validation_1-error:0.04386  validation_0-f1_err:0.008977    validation_1-f1_err:0.031447
[4] validation_0-error:0.01978  validation_1-error:0.04386  validation_0-f1_err:0.010753    validation_1-f1_err:0.031447
[5] validation_0-error:0.015385 validation_1-error:0.035088 validation_0-f1_err:0.008977    validation_1-f1_err:0.025316
[6] validation_0-error:0.013187 validation_1-error:0.04386  validation_0-f1_err:0.010676    validation_1-f1_err:0.025316
[7] validation_0-error:0.017582 validation_1-error:0.04386  validation_0-f1_err:0.010638    validation_1-f1_err:0.018868
[8] validation_0-error:0.013187 validation_1-error:0.04386  validation_0-f1_err:0.008913    validation_1-f1_err:0.025
[9] validation_0-error:0.008791 validation_1-error:0.04386  validation_0-f1_err:0.007143    validation_1-f1_err:0.025
[10]    validation_0-error:0.010989 validation_1-error:0.04386  validation_0-f1_err:0.007143    validation_1-f1_err:0.025
[11]    validation_0-error:0.008791 validation_1-error:0.04386  validation_0-f1_err:0.007143    validation_1-f1_err:0.025
[12]    validation_0-error:0.008791 validation_1-error:0.052632 validation_0-f1_err:0.007143    validation_1-f1_err:0.025
[13]    validation_0-error:0.008791 validation_1-error:0.052632 validation_0-f1_err:0.007117    validation_1-f1_err:0.025
[14]    validation_0-error:0.008791 validation_1-error:0.052632 validation_0-f1_err:0.005348    validation_1-f1_err:0.018868
[15]    validation_0-error:0.008791 validation_1-error:0.052632 validation_0-f1_err:0.005348    validation_1-f1_err:0.018868
[16]    validation_0-error:0.008791 validation_1-error:0.052632 validation_0-f1_err:0.005348    validation_1-f1_err:0.018868
[17]    validation_0-error:0.008791 validation_1-error:0.052632 validation_0-f1_err:0.005348    validation_1-f1_err:0.018868
[18]    validation_0-error:0.008791 validation_1-error:0.052632 validation_0-f1_err:0.005348    validation_1-f1_err:0.018868
[19]    validation_0-error:0.008791 validation_1-error:0.052632 validation_0-f1_err:0.005348    validation_1-f1_err:0.018868
[20]    validation_0-error:0.008791 validation_1-error:0.052632 validation_0-f1_err:0.005348    validation_1-f1_err:0.018868
[21]    validation_0-error:0.006593 validation_1-error:0.052632 validation_0-f1_err:0.005348    validation_1-f1_err:0.018868
[22]    validation_0-error:0.006593 validation_1-error:0.052632 validation_0-f1_err:0.003571    validation_1-f1_err:0.018868
[23]    validation_0-error:0.006593 validation_1-error:0.052632 validation_0-f1_err:0.003571    validation_1-f1_err:0.018868
[24]    validation_0-error:0.006593 validation_1-error:0.052632 validation_0-f1_err:0.003571    validation_1-f1_err:0.018868
[25]    validation_0-error:0.006593 validation_1-error:0.052632 validation_0-f1_err:0.003571    validation_1-f1_err:0.018868
[26]    validation_0-error:0.004396 validation_1-error:0.052632 validation_0-f1_err:0.003571    validation_1-f1_err:0.018868
[27]    validation_0-error:0.004396 validation_1-error:0.052632 validation_0-f1_err:0.003584    validation_1-f1_err:0.018868
[28]    validation_0-error:0.004396 validation_1-error:0.052632 validation_0-f1_err:0.003584    validation_1-f1_err:0.018868
[29]    validation_0-error:0.004396 validation_1-error:0.052632 validation_0-f1_err:0.003571    validation_1-f1_err:0.018868
[30]    validation_0-error:0.004396 validation_1-error:0.052632 validation_0-f1_err:0.001789    validation_1-f1_err:0.018868
[31]    validation_0-error:0.004396 validation_1-error:0.052632 validation_0-f1_err:0.001789    validation_1-f1_err:0.018868
[32]    validation_0-error:0.004396 validation_1-error:0.052632 validation_0-f1_err:0.001789    validation_1-f1_err:0.018868
Stopping. Best iteration:
[7] validation_0-error:0.017582 validation_1-error:0.04386  validation_0-f1_err:0.010638    validation_1-f1_err:0.018868

XGBClassifier(base_score=0.6131868131868132, booster='gbtree',
       colsample_bylevel=1, colsample_bytree=0.8, gamma=0,
       learning_rate=0.1, max_delta_step=0, max_depth=6,
       min_child_weight=1, missing=None, n_estimators=100, n_jobs=8,
       niterations=1000, nthread=None, objective='binary:logistic',
       random_state=0, reg_alpha=0, reg_lambda=1,
       scale_pos_weight=0.6308243727598566, seed=None, silent=False,
       subsample=0.8)

Answer

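You get 4 evaluation metrics because xgboost reports the default metric of the objective in addition to your custom one. With objective='binary:logistic', that default is 'error' in the version shown in your log, so each of the two entries in eval_set (validation_0 = (X_train, y_train), validation_1 = (X_test, y_test)) is scored with both 'error' and your 'f1_err': 2 metrics × 2 datasets = 4 columns. Personally, I use the core xgboost API rather than the scikit-learn wrapper, so if you want to learn more, read about it in the documentation.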
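A small sketch of where the four column names come from. This is a plain-Python model of the naming pattern visible in the log, not xgboost's own code, and the helper log_columns is hypothetical: every metric is printed for every eval_set entry, built-in metrics first and the custom metric last, and (as the log message "'validation_1-f1_err' will be used for early stopping" shows) the last column is the one early stopping watches.

```python
def log_columns(n_eval_sets, builtin_metrics, custom_metric):
    """Hypothetical helper: rebuild the per-round column names seen in the log.

    Every metric is reported for every eval_set entry (validation_0,
    validation_1, ...); built-in metrics come first, the custom one last.
    """
    names = []
    for metric in list(builtin_metrics) + [custom_metric]:
        for i in range(n_eval_sets):
            names.append(f"validation_{i}-{metric}")
    return names


# Two eval_set entries x (default 'error' + custom 'f1_err') = 4 columns.
cols = log_columns(2, ["error"], "f1_err")
print(cols)
# The last column is the one used for early stopping.
print("early stopping metric:", cols[-1])
```

To get only your own metric in the output, the built-in one has to be switched off or the eval_set reduced to the test split; the count is simply metrics times datasets.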

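For early stopping, set n_estimators=1000 (or however many boosting rounds you want as the upper limit) as a parameter of xgb.XGBClassifier. niterations is not an XGBClassifier parameter, which is why your fitted model still shows n_estimators=100.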

Then set early_stopping_rounds=50 (or whatever value you want) in clf.fit. See the documentation for details.

Early stopping decides when to stop boosting in order to avoid overfitting. It does so by evaluating the last tuple you defined in eval_set, here (X_test, y_test), with the last metric reported (in your log, 'validation_1-f1_err'). If that evaluation error hasn't decreased for 50 consecutive iterations, early stopping halts the boosting.
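The rule can be replayed on the log from the question. The sketch below is a simplified re-implementation of patience-based early stopping (a round counts only if its score is strictly lower than the best so far), not xgboost's actual callback; the scores are the validation_1-f1_err values printed for rounds 0 to 32. With a patience of 25 it reports best round 7 and stops at round 32, which matches "Stopping. Best iteration: [7]" (round 14 ties round 7 at 0.018868 and therefore does not reset the counter).

```python
def early_stop(scores, patience):
    """Return (best_round, stop_round) for a lower-is-better metric.

    A round only counts as an improvement if its score is strictly lower
    than the best so far; training stops once `patience` rounds have
    passed since the best round.
    """
    best_round, best_score = 0, scores[0]
    for i, score in enumerate(scores):
        if score < best_score:
            best_round, best_score = i, score
        if i - best_round >= patience:
            return best_round, i
    return best_round, len(scores) - 1  # never stopped: all rounds were used


# validation_1-f1_err for rounds 0..32, copied from the log above.
f1_err_log = [0.031847, 0.031056, 0.037037, 0.031447, 0.031447,
              0.025316, 0.025316, 0.018868] + [0.025] * 6 + [0.018868] * 19
print(early_stop(f1_err_log, patience=25))  # (7, 32)
```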