scikit learn - feature importance calculation in decision trees


Problem Description



I'm trying to understand how feature importance is calculated for decision trees in sci-kit learn. This question has been asked before, but I am unable to reproduce the results the algorithm is providing.

For example:

from StringIO import StringIO

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree.export import export_graphviz
from sklearn.feature_selection import mutual_info_classif

# A toy dataset with three binary features
X = [[1, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 0]]

y = [1, 0, 1, 1]

clf = DecisionTreeClassifier()
clf.fit(X, y)

# Unnormalized feature importances taken directly from the fitted tree
feat_importance = clf.tree_.compute_feature_importances(normalize=False)
print("feat importance = " + str(feat_importance))

# Export the fitted tree to a .dot file for visualization
out = StringIO()
out = export_graphviz(clf, out_file='test/tree.dot')

results in feature importance:

feat importance = [0.25       0.08333333 0.04166667]

and gives the following decision tree:

Now, this answer to a similar question suggests the importance is calculated as

where G is the node impurity, in this case the Gini impurity. As far as I understand, this is the impurity reduction. However, for feature 1 this should be:

This answer suggests the importance is weighted by the probability of reaching the node (which is approximated by the proportion of samples reaching that node). Again, for feature 1 this should be:

Both formulas provide the wrong result. How is the feature importance calculated correctly?

Solution

I think the feature importance depends on the implementation, so we need to look at the scikit-learn documentation.

The feature importances. The higher, the more important the feature. The importance of a feature is computed as the (normalized) total reduction of the criterion brought by that feature. It is also known as the Gini importance

That reduction, or weighted information gain, is defined as:

The weighted impurity decrease equation is the following:

N_t / N * (impurity - N_t_R / N_t * right_impurity - N_t_L / N_t * left_impurity)

where N is the total number of samples, N_t is the number of samples at the current node, N_t_L is the number of samples in the left child, and N_t_R is the number of samples in the right child.

http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier
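
As a concrete illustration, the quoted formula can be written as a small Python helper. This is only a sketch: the function name weighted_impurity_decrease is mine, not part of scikit-learn's API, and the example values plugged in are the ones for the root split (on X[2]) that are worked out below.

def weighted_impurity_decrease(N, N_t, N_t_L, N_t_R,
                               impurity, left_impurity, right_impurity):
    # N_t / N * (impurity - N_t_R / N_t * right_impurity
    #                     - N_t_L / N_t * left_impurity)
    return (N_t / N) * (impurity
                        - (N_t_R / N_t) * right_impurity
                        - (N_t_L / N_t) * left_impurity)

# Root split: 4 of 4 samples; the left child holds 3 samples (gini 0.444),
# the right child holds 1 sample (pure, gini 0.0)
print(weighted_impurity_decrease(N=4, N_t=4, N_t_L=3, N_t_R=1,
                                 impurity=0.375,
                                 left_impurity=0.444, right_impurity=0.0))
# prints roughly 0.042, the value computed for X[2] below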

Since each feature is used only once in your case, the importance of a feature equals the equation above evaluated at the single node where that feature is used for splitting.

For X[2]:

feature_importance = (4 / 4) * (0.375 - (0.75 * 0.444)) = 0.042

For X[1]:

feature_importance = (3 / 4) * (0.444 - (2/3 * 0.5)) = 0.083

For X[0] (its two children are pure leaves, so their impurity is 0):

feature_importance = (2 / 4) * (0.5 - 0) = 0.25
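
To double-check these numbers, here is a minimal sketch (not part of the original answer) that recomputes the per-node weighted impurity decrease directly from the arrays stored in clf.tree_ and compares it with compute_feature_importances(normalize=False). The random_state argument is only added to make the run reproducible; everything else mirrors the toy example from the question.

import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = [[1, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 0]]
y = [1, 0, 1, 1]

clf = DecisionTreeClassifier(random_state=0)
clf.fit(X, y)

tree = clf.tree_
N = tree.weighted_n_node_samples[0]        # samples at the root
importances = np.zeros(len(X[0]))          # one entry per feature

for node in range(tree.node_count):
    left = tree.children_left[node]
    right = tree.children_right[node]
    if left == -1:                         # -1 marks a leaf node
        continue
    n_t = tree.weighted_n_node_samples[node]
    n_l = tree.weighted_n_node_samples[left]
    n_r = tree.weighted_n_node_samples[right]
    decrease = (n_t / N) * (tree.impurity[node]
                            - (n_r / n_t) * tree.impurity[right]
                            - (n_l / n_t) * tree.impurity[left])
    importances[tree.feature[node]] += decrease

print("manual  =", importances)
print("sklearn =", tree.compute_feature_importances(normalize=False))

The two printed arrays should agree up to floating-point noise, matching the hand calculation above.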
