How to display the path of a Decision Tree for test samples?

Question

I'm using DecisionTreeClassifier from scikit-learn to classify some multiclass data. I found many posts describing how to display the decision tree path, like here, here, and here. However, all of them describe how to display the tree for the trained data. It makes sense, because export_graphviz only requires a fitted model.

My question is: how do I visualize the tree for the test samples (preferably with export_graphviz)? That is, after fitting the model with clf.fit(X[train], y[train]) and predicting the results for the test data with clf.predict(X[test]), I want to visualize the decision path used to predict the samples X[test]. Is there a way to do that?

Edit:

I see that the path can be printed using decision_path. If there's a way to get DOT output, as with export_graphviz, to display it, that would be great.

Solution

To get the path a particular sample takes through a decision tree, you can use decision_path. It returns a sparse indicator matrix of shape (n_samples, n_nodes) in which a non-zero entry means that the sample passes through that node.
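
For reference, here is a minimal sketch of reading those paths for a held-out test set. It assumes a split made with `train_test_split` (the split, `random_state=42` and the helper `describe_path` are illustrative choices, not part of the original answer). It uses the fitted tree's `tree_.feature` and `tree_.threshold` arrays, the CSR layout (`indptr`/`indices`) of the matrix returned by `decision_path`, and `clf.apply` to find the leaf each sample ends in.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
# Illustrative split; use your own X[train] / X[test] here.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=42)

clf = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Sparse CSR indicator matrix of shape (n_test_samples, n_nodes):
# entry [i, j] is non-zero if test sample i passes through node j.
node_indicator = clf.decision_path(X_test)

# Id of the leaf node each test sample ends up in.
leaf_ids = clf.apply(X_test)

feature = clf.tree_.feature
threshold = clf.tree_.threshold


def describe_path(sample_id):
    """Return one readable line per node visited by X_test[sample_id]."""
    # In CSR format, the non-zero column indices of row i are
    # indices[indptr[i]:indptr[i + 1]].
    start = node_indicator.indptr[sample_id]
    end = node_indicator.indptr[sample_id + 1]
    steps = []
    for node_id in node_indicator.indices[start:end]:
        if node_id == leaf_ids[sample_id]:
            steps.append(f"node {node_id}: leaf")
            continue
        name = iris.feature_names[feature[node_id]]
        value = X_test[sample_id, feature[node_id]]
        # scikit-learn sends a sample to the left child when value <= threshold.
        sign = "<=" if value <= threshold[node_id] else ">"
        steps.append(f"node {node_id}: {name} = {value} {sign} {threshold[node_id]:.2f}")
    return steps


for line in describe_path(0):
    print(line)
```

Each row lists the visited nodes from the root down to the leaf, because scikit-learn fills the indicator while walking the tree from the root.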

Those decision paths can then be used to color and relabel the tree generated via pydotplus. This requires overwriting each node's fill color and label, which results in somewhat ugly code.

Notes

  • decision_path can take samples from the training set or new values
  • you can go wild with the colors and change the color according to the number of samples or whatever other visualization might be needed
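
As one way to follow the second note, nodes can be shaded by how many samples reach them: count the visits per node and map each count to a color between white and green. This is a minimal pure-Python sketch; `visit_counts` and `shade` are hypothetical helpers, and the toy `paths` stand in for the non-zero column indices of each row of `clf.decision_path(X[test])`.

```python
from collections import Counter


def visit_counts(paths):
    """Count how many samples pass through each node.
    `paths` is an iterable of node-id sequences, one per sample."""
    counts = Counter()
    for path in paths:
        counts.update(set(path))
    return counts


def shade(count, max_count, rgb=(0, 128, 0)):
    """Interpolate linearly from white (count 0) to `rgb` (count == max_count)
    and return a Graphviz-style "#rrggbb" color string."""
    t = count / max_count if max_count else 0.0
    r, g, b = (round(255 + (c - 255) * t) for c in rgb)
    return f"#{r:02x}{g:02x}{b:02x}"


# Toy root-to-leaf paths for three samples (illustrative input only).
paths = [[0, 2, 3], [0, 2, 4], [0, 1]]
counts = visit_counts(paths)
colors = {node: shade(n, max(counts.values())) for node, n in counts.items()}
print(colors)
```

The resulting `colors` mapping can be applied node by node with `node.set_fillcolor(...)` in the pydotplus loop of the example.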

Example

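In the example below, the nodes visited by the sample are colored green and all other nodes white. Each node's samples count is reset to zero and then incremented once for every provided sample that passes through it.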

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)

# empty all nodes, i.e. set color to white and number of samples to zero
for node in graph.get_node_list():
    if node.get_attributes().get('label') is None:
        continue
    if 'samples = ' in node.get_attributes()['label']:
        labels = node.get_attributes()['label'].split('<br/>')
        for i, label in enumerate(labels):
            if label.startswith('samples = '):
                labels[i] = 'samples = 0'
        node.set('label', '<br/>'.join(labels))
        node.set_fillcolor('white')

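samples = iris.data[129:130]  # one sample as a 2-D array; pass X[test] here instead
decision_paths = clf.decision_path(samples)

# color every node on each sample's path and add 1 to its samples count
for decision_path in decision_paths:
    for n, node_value in enumerate(decision_path.toarray()[0]):
        if node_value == 0:
            continue
        node = graph.get_node(str(n))[0]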
        node.set_fillcolor('green')
        labels = node.get_attributes()['label'].split('<br/>')
        for i, label in enumerate(labels):
            if label.startswith('samples = '):
                labels[i] = 'samples = {}'.format(int(label.split('=')[1]) + 1)

        node.set('label', '<br/>'.join(labels))

filename = 'tree.png'
graph.write_png(filename)
