Prune unnecessary leaves in sklearn DecisionTreeClassifier


Question

I use sklearn.tree.DecisionTreeClassifier to build a decision tree. With the optimal parameter settings, I get a tree that has unnecessary leaves (see the example picture below - I do not need probabilities, so the leaf nodes marked in red are an unnecessary split).

Is there any third-party library for pruning these unnecessary nodes? Or a code snippet? I could write one, but I can't really imagine that I am the first person with this problem...

Code to reproduce:

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
iris = datasets.load_iris()
X = iris.data
y = iris.target
mdl = DecisionTreeClassifier(max_leaf_nodes=8)
mdl.fit(X, y)
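To see the splits the question calls unnecessary, one can scan the fitted tree for internal nodes whose two children are both leaves predicting the same class. A small diagnostic sketch (it uses `max_depth=3` and a fixed `random_state` instead of the question's `max_leaf_nodes=8` - an assumption made purely so the redundant pair on iris is deterministic):

```python
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._tree import TREE_LEAF

iris = datasets.load_iris()
mdl = DecisionTreeClassifier(max_depth=3, random_state=0)
mdl.fit(iris.data, iris.target)

tree = mdl.tree_
decisions = tree.value.argmax(axis=2).flatten()  # majority class per node

def is_leaf(t, i):
    return t.children_left[i] == TREE_LEAF and t.children_right[i] == TREE_LEAF

# Internal nodes whose children are both leaves with the same majority class:
# such a split reduces impurity but never changes the predicted class.
redundant = [
    i for i in range(tree.node_count)
    if not is_leaf(tree, i)
    and is_leaf(tree, tree.children_left[i])
    and is_leaf(tree, tree.children_right[i])
    and decisions[tree.children_left[i]] == decisions[tree.children_right[i]]
]
print("redundant split nodes:", redundant)
```

On iris this finds the classic split of a nearly pure virginica node into two leaves that both still predict virginica.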

PS: I have tried multiple keyword searches and am kind of surprised to find nothing - is there really no general post-pruning in sklearn?

PPS: In response to the possible duplicate: while the suggested question might help me when coding the pruning algorithm myself, it answers a different question - I want to get rid of leaves that do not change the final decision, while the other question wants a minimum threshold for splitting nodes.

PPPS: The tree shown is an example to illustrate my problem. I am aware that the parameter settings used to create it are suboptimal. I am not asking about optimizing this specific tree; I need post-pruning to get rid of leaves that might be helpful if one needs class probabilities, but are not helpful if one is only interested in the most likely class.

Answer

Using ncfirth's link, I was able to modify the code there so that it fits my problem:

from sklearn.tree._tree import TREE_LEAF

def is_leaf(inner_tree, index):
    # Check whether node is leaf node
    return (inner_tree.children_left[index] == TREE_LEAF and 
            inner_tree.children_right[index] == TREE_LEAF)

def prune_index(inner_tree, decisions, index=0):
    # Start pruning from the bottom - if we start from the top, we might miss
    # nodes that become leaves during pruning.
    # Do not use this directly - use prune_duplicate_leaves instead.
    if not is_leaf(inner_tree, inner_tree.children_left[index]):
        prune_index(inner_tree, decisions, inner_tree.children_left[index])
    if not is_leaf(inner_tree, inner_tree.children_right[index]):
        prune_index(inner_tree, decisions, inner_tree.children_right[index])

    # Prune children if both children are leaves now and make the same decision:     
    if (is_leaf(inner_tree, inner_tree.children_left[index]) and
        is_leaf(inner_tree, inner_tree.children_right[index]) and
        (decisions[index] == decisions[inner_tree.children_left[index]]) and 
        (decisions[index] == decisions[inner_tree.children_right[index]])):
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
        ##print("Pruned {}".format(index))

def prune_duplicate_leaves(mdl):
    # Remove a split when both resulting leaves make the same decision as their parent
    decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist() # Decision for each node
    prune_index(mdl.tree_, decisions)

Using this on a DecisionTreeClassifier clf:

prune_duplicate_leaves(clf)

Edit: Fixed a bug for more complex trees
