在scikit-learn中的数据集上绘制决策树 [英] Plot decision tree over dataset in scikit-learn

查看:67
本文介绍了在scikit-learn中的数据集上绘制决策树的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我一直在尝试将测试随机分为测试集和训练集,并在5层深的决策树上进行训练,并绘制决策树.

I've been trying to divide randomly into test and train sets my dataset and train on a 5 deep decision tree and plot the decision tree.

P.s.我不允许使用熊猫.

P.s. I'm not allowed to use pandas to do so.

这是我尝试做的事情:

import numpy
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn import tree
from sklearn.model_selection import train_test_split
filename = 'diabetes.csv'
raw_data = open(filename, 'rt')
data = numpy.loadtxt(raw_data, delimiter=",", skiprows=1)
print(data.shape)

X = data[:,0:8] #identify columns as data sets
Y = data[:, 9] #identfy last column as target
print(X)
print(Y)
X_train, X_test, Y_train, Y_test = train_test_split(
X, Y, test_size=0.25)
treeClassifier = DecisionTreeClassifier(max_depth=5)
treeClassifier.fit(X_train, Y_train)
with open("treeClassifier.txt", "w") as f:
 f = tree.export_graphviz(treeClassifier, out_file=f)

我的输出是:

(768, 10)
[[  6.    148.     72.    ...  33.6     0.627  50.   ]
[  1.     85.     66.    ...  26.6     0.351  31.   ]
[  8.    183.     64.    ...  23.3     0.672  32.   ]
 ...
[  5.    121.     72.    ...  26.2     0.245  30.   ]
[  1.    126.     60.    ...  30.1     0.349  47.   ]
[  1.     93.     70.    ...  30.4     0.315  23.   ]]
[1. 0. 1. 0. 1. 0. 1. 0. 1. 1. 0. 1. 0. 1. 1. 1. 1. 1. 0. 1. 0. 0. 1. 1.
 1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 1. 0. 1. 0. 0.
 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 1. 0.
 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0.
 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 1. 1. 1. 0. 0. 0.
 1. 0. 0. 0. 1. 1. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.
 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 1. 0. 0. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0.
 0. 0. 1. 1. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0.
 1. 1. 0. 1. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 1. 0. 0. 0. 1. 1. 1.
 1. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 0. 0. 0. 1. 1. 1. 1. 0.
 0. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 0. 1. 0. 0.
 1. 0. 1. 0. 0. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1. 0. 0. 1.
 0. 0. 0. 1. 1. 1. 0. 0. 1. 0. 1. 0. 1. 1. 0. 1. 0. 0. 1. 0. 1. 1. 0. 0.
 1. 0. 1. 0. 0. 1. 0. 1. 0. 1. 1. 1. 0. 0. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0.
 0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 1. 1. 0. 1.
 1. 0. 0. 1. 0. 0. 1. 0. 0. 1. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 0. 1. 1. 1. 0. 0. 1. 0. 0. 1. 0. 0. 1. 0. 1. 1. 0. 1. 0. 1. 0. 1. 0.
 1. 1. 0. 0. 0. 0. 1. 1. 0. 1. 0. 1. 0. 0. 0. 0. 1. 1. 0. 1. 0. 1. 0. 0.
 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 1. 1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 1.
 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.
 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0.
 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1. 0. 0. 1. 1. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
 0. 1. 0. 1. 1. 0. 0. 0. 1. 0. 1. 0. 1. 0. 1. 0. 1. 0. 0. 1. 0. 0. 1. 0.
 0. 0. 0. 1. 1. 0. 1. 0. 0. 0. 0. 1. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 1. 1.
 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 1. 1. 1. 0. 1. 1. 0. 0. 0. 0.
 0. 0. 0. 1. 1. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 1.
 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 0. 1. 1. 0. 0. 1. 0. 0. 1. 1. 0. 0. 1.
 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 1.
 0. 0. 1. 0. 1. 1. 1. 0. 0. 1. 1. 1. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0.]

以下是我希望生成的树看起来像的一个示例:

Here is an example of what I want the resulting tree to look like:

我遇到的问题是在我的树中,我没有获得'class = 0 \ class = 1'属性.我认为问题可能出在Y = data[:, 9]部分,第9列对它是0还是1进行了分类-这是class属性,但是我看不到有任何方法可以对其进行更改以使其出现在树中;也许tree.export_graphviz函数中的某些内容?我是否缺少参数?任何帮助将不胜感激.

The problem I'm having is that in my tree, I don't get the 'class=0\ class=1' attribute. I thought the problem might be in the Y = data[:, 9] part, the 9th column classifies if it's a 0 or a 1 -- this is the class attribute, but I don't see any way to change it to make it appear in the tree; maybe something in the tree.export_graphviz function? Am I missing a parameter? Any help would be appreciated.

推荐答案

如果替换

tree.export_graphviz(treeClassifier, out_file=f)

使用

tree.export_graphviz(treeClassifier, class_names=['0', '1'], out_file=f)

你应该很好.

例如,

import graphviz
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from sklearn.model_selection import train_test_split

np.random.seed(42)
X = np.random.random((100, 8))
Y = np.random.randint(2, size=100)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.25)
tree_classifier = DecisionTreeClassifier(max_depth=5)
tree_classifier.fit(X_train, Y_train)

dot_data = tree.export_graphviz(tree_classifier, class_names=['0', '1'], out_file=None)
graph = graphviz.Source(dot_data)
graph

要使其看起来更像您引用的示例,可以使用

To make it look even more like the example you refer to, you can use

tree.export_graphviz(treeClassifier, class_names=['0', '1'],
                     filled=True, rounded=True, out_file=f)

这篇关于在scikit-learn中的数据集上绘制决策树的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆