如何将sklearn决策树规则提取到 pandas 布尔条件? [英] How to extract sklearn decision tree rules to pandas boolean conditions?

查看:298
本文介绍了如何将sklearn决策树规则提取到 pandas 布尔条件?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

这样的帖子太多关于如何提取sklearn决策树规则,但我找不到有关使用熊猫的任何信息.

There are so many posts like this about how to extract sklearn decision tree rules but I could not find any about using pandas.

例如,以此数据和模型为例,如下所示

Take this data and model for example, as below

# Create Decision Tree classifer object
clf = DecisionTreeClassifier(criterion="entropy", max_depth=3)

# Train Decision Tree Classifer
clf = clf.fit(X_train,y_train)

结果:

预期:

关于此示例,有8条规则.

There're 8 rules about this example.

从左到右,请注意数据帧为df

From left to right,notice that dataframe is df

r1 = (df['glucose']<=127.5) & (df['bmi']<=26.45) & (df['bmi']<=9.1)
……
r8 =  (df['glucose']>127.5) & (df['bmi']>28.15) & (df['glucose']>158.5)

我不是提取sklearn决策树规则的大师.获取大熊猫布尔条件将有助于我为每个规则计算样本和其他指标.因此,我想将每个规则提取到熊猫的布尔条件.

I'm not a master of extracting sklearn decision tree rules. Getting the pandas boolean conditions will help me calculate samples and other metrics for each rule. So I want to extract each rule to a pandas boolean condition.

推荐答案

首先让我们使用scikit 文档,以获取有关构造的树的信息:

First of all let's use the scikit documentation on decision tree structure to get information about the tree that was constructed :

n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold

然后,我们定义两个递归函数.第一个将找到从树的根开始创建特定节点的路径(在本例中为所有叶子).第二个将使用其创建路径编写用于创建节点的特定规则:

We then define two recursive functions. The first one will find the path from the tree's root to create a specific node (all the leaves in our case). The second one will write the specific rules used to create a node using its creation path :

def find_path(node_numb, path, x):
        path.append(node_numb)
        if node_numb == x:
            return True
        left = False
        right = False
        if (children_left[node_numb] !=-1):
            left = find_path(children_left[node_numb], path, x)
        if (children_right[node_numb] !=-1):
            right = find_path(children_right[node_numb], path, x)
        if left or right :
            return True
        path.remove(node_numb)
        return False


def get_rule(path, column_names):
    mask = ''
    for index, node in enumerate(path):
        #We check if we are not in the leaf
        if index!=len(path)-1:
            # Do we go under or over the threshold ?
            if (children_left[node] == path[index+1]):
                mask += "(df['{}']<= {}) \t ".format(column_names[feature[node]], threshold[node])
            else:
                mask += "(df['{}']> {}) \t ".format(column_names[feature[node]], threshold[node])
    # We insert the & at the right places
    mask = mask.replace("\t", "&", mask.count("\t") - 1)
    mask = mask.replace("\t", "")
    return mask

最后,我们使用这两个函数来首先存储每个叶子的创建路径.然后存储用于创建每个叶子的规则:

Finally, we use those two functions to first store the creation path of each leaf. And then to store the rules used to create each leaf :

# Leaves
leave_id = clf.apply(X_test)

paths ={}
for leaf in np.unique(leave_id):
    path_leaf = []
    find_path(0, path_leaf, leaf)
    paths[leaf] = np.unique(np.sort(path_leaf))

rules = {}
for key in paths:
    rules[key] = get_rule(paths[key], pima.columns)


根据您提供的数据,输出为:


With the data you gave the output is :

rules =
{3: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']<= 9.100000381469727)  ",
 4: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']> 9.100000381469727)  ",
 6: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']<= 27.5)  ",
 7: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']> 27.5)  ",
 10: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']<= 145.5)  ",
 11: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']> 145.5)  ",
 13: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']<= 158.5)  ",
 14: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']> 158.5)  "}

由于规则是字符串,因此不能使用df[rules[3]]直接调用它们,因此必须像df[eval(rules[3])]

Since the rules are strings, you can't directly call them using df[rules[3]], you have to use the eval function like so df[eval(rules[3])]

这篇关于如何将sklearn决策树规则提取到 pandas 布尔条件?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆