Coloring the nodes of a decision tree drawn with Graphviz, using class_names

Question

Extending a previous question: Changing colors for a decision tree plot created using export_graphviz

How would I color the nodes of the tree based on the dominant class (species of iris) instead of a binary distinction? This should require combining iris.target_names, the strings describing the classes, with iris.target, the class labels.

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
import collections

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()
edges = graph.get_edge_list()

colors = ('brown', 'forestgreen')
# map each parent node to the ids of its two children
edges = collections.defaultdict(list)

for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))

# color the first (lower-id) child of every split brown and the second green
for edge in edges:
    edges[edge].sort()
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])

graph.write_png('tree.png')

Answer

The code from the example looks so familiar and is therefore easy to modify :)

For each node, the label written by export_graphviz tells us how many samples from each class it contains, i.e. whether it is a mixed population or the tree has come to a decision. We can extract this information and use it to pick a color.

values = [float(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
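The one-line split above assumes the label holds a single flat "value = [...]" list. A slightly more defensive sketch, assuming the label format of the illustrative string below (parse_values is a hypothetical helper, not part of pydotplus or sklearn), uses regular expressions and float() so it also tolerates decimal values and a line break inside the brackets in case the exporter wraps a long list.

```python
import re

# Hypothetical helper (not part of pydotplus or sklearn): extract the numbers
# between "value = [" and the closing "]" of an export_graphviz node label.
# Assumes a single-output tree, i.e. one flat list of per-class values.
_VALUE_RE = re.compile(r"value = \[([^\]]*)\]")
_NUMBER_RE = re.compile(r"\d+(?:\.\d*)?(?:[eE][-+]?\d+)?")


def parse_values(label):
    """Return the per-class values in a node label, or None if absent."""
    if not label:
        return None  # the default 'node'/'edge' entries carry no label
    match = _VALUE_RE.search(label)
    if match is None:
        return None
    # findall skips separators such as ", " or an HTML "<br/>" line break
    return [float(x) for x in _NUMBER_RE.findall(match.group(1))]


# Illustrative label in the style produced with special_characters=True
label = "<gini = 0.667<br/>samples = 150<br/>value = [50, 50, 50]<br/>class = setosa>"
print(parse_values(label))  # [50.0, 50.0, 50.0]
```

Returning None for label-less entries lets the coloring loop skip them with a single check.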

Alternatively, you can map the Graphviz nodes back to the sklearn nodes, because export_graphviz names each Graphviz node after its sklearn node id:

values = clf.tree_.value[int(node.get_name())][0]
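A small runnable sketch of that mapping, assuming scikit-learn and NumPy are installed; node_distribution is a hypothetical helper name. Depending on the scikit-learn version, tree_.value may hold raw class counts or per-class fractions, so the sketch normalizes the row before using it.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)


def node_distribution(clf, node_name):
    """Hypothetical helper: class shares of the sklearn node behind a Graphviz node."""
    # export_graphviz uses the sklearn node id as the Graphviz node name
    values = clf.tree_.value[int(node_name)][0]  # shape: (n_classes,)
    # normalize so the result is the same whether tree_.value holds counts or fractions
    return values / values.sum()


root = node_distribution(clf, "0")
print(root)  # the iris root node holds 50 samples of each class
```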

We only have 3 classes, so each one gets its own color channel (red, green, blue), and mixed populations get mixed colors according to their class distribution.

values = [int(255 * v / sum(values)) for v in values]
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
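The formula above hard-wires three classes to the red, green and blue channels. One way to extend the same idea (an assumption, not part of the original answer) is to give every class a base RGB color and average those colors weighted by the class shares. blend_color and DEFAULT_PALETTE are hypothetical names; with the red/green/blue default palette the sketch should reproduce the two lines above.

```python
# Hypothetical generalization (not from the original answer): each class gets a
# base RGB color, and a node's fill is the average of those colors weighted by
# the share of the node's samples in each class.
DEFAULT_PALETTE = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]  # red, green, blue


def blend_color(values, palette=DEFAULT_PALETTE):
    """Return a '#rrggbb' color for a list of per-class sample counts."""
    if len(values) > len(palette):
        raise ValueError("need at least one palette color per class")
    total = float(sum(values))
    rgb = [0.0, 0.0, 0.0]
    for share, base in zip(values, palette):
        for channel in range(3):
            rgb[channel] += base[channel] * share / total
    # int() truncates each channel, like the original formula
    return "#{:02x}{:02x}{:02x}".format(*(int(c) for c in rgb))


print(blend_color([50, 0, 0]))   # #ff0000: a pure first-class node
print(blend_color([0, 25, 25]))  # #007f7f: half green, half blue
```

Passing a longer palette of RGB tuples is all it takes to support more classes, although blended colors become harder to read as the number of classes grows.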

We can now see the separation nicely: the greener a node, the more samples of the 2nd class it contains, and likewise for blue and the 3rd class.

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf,
                                feature_names=iris.feature_names,
                                out_file=None,
                                filled=True,
                                rounded=True,
                                special_characters=True)
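graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()

for node in nodes:
    # the default 'node'/'edge' attribute entries have no label, skip them
    if node.get_label():
        # class distribution printed in the label, e.g. "value = [50, 50, 50]"
        values = [float(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
        # scale each class's share to 0-255: 1st class -> red, 2nd -> green, 3rd -> blue
        values = [int(255 * v / sum(values)) for v in values]
        color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
        node.set_fillcolor(color)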

graph.write_png('colored_tree.png')
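For the "dominant class" coloring the question asks about, a minimal sketch is to take the class with the most samples and look up its color by name in class_names (iris.target_names in the question). CLASS_COLORS, dominant_class and dominant_color are hypothetical names, and the color choices are arbitrary. Note that filled=True on its own already shades nodes by their majority class; a sketch like this is only needed for a custom palette.

```python
# Hypothetical class-name -> Graphviz color mapping for the iris species
CLASS_COLORS = {
    "setosa": "lightblue",
    "versicolor": "lightyellow",
    "virginica": "forestgreen",
}


def dominant_class(values, class_names):
    """Name of the class with the most samples (first one wins on a tie)."""
    best = max(range(len(values)), key=lambda i: values[i])
    return class_names[best]


def dominant_color(values, class_names, colors=CLASS_COLORS):
    """Fill color for a node, looked up by the name of its dominant class."""
    return colors[dominant_class(values, class_names)]


names = ["setosa", "versicolor", "virginica"]  # same order as iris.target_names
print(dominant_class([0, 49, 5], names))  # versicolor
print(dominant_color([0, 1, 45], names))  # forestgreen
```

Inside the loop of the answer's code this would be node.set_fillcolor(dominant_color(values, iris.target_names)), called on the parsed values before they are rescaled to 0-255.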


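A general solution for more than 3 classes, which colors only the final nodes, i.e. the nodes that contain a single class (the colors tuple needs one entry per class plus one for mixed nodes):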

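import numpy as np

# one fill color per class; the last entry is used for mixed nodes
colors = ('lightblue', 'lightyellow', 'forestgreen', 'lightcoral', 'white')

for node in nodes:
    # skip the default 'node' and 'edge' attribute statements
    if node.get_name() not in ('node', 'edge'):
        # Graphviz node names are sklearn node ids
        values = clf.tree_.value[int(node.get_name())][0]
        # color only nodes where only one class is present
        if max(values) == sum(values):
            node.set_fillcolor(colors[np.argmax(values)])
        # mixed nodes get the default color
        else:
            node.set_fillcolor(colors[-1])

graph.write_png('colored_tree.png')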
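The colors tuple above covers at most four classes, with the last entry reserved for mixed nodes. For an arbitrary number of classes, one option (an assumption, not from the original answer) is to generate evenly spaced pastel hues with Python's standard colorsys module; Graphviz accepts "#rrggbb" strings as fill colors. make_palette and pick_fill are hypothetical helpers that follow the same rule as the loop above: single-class nodes get their class color, everything else gets white.

```python
import colorsys


def make_palette(n_classes, saturation=0.4, value=1.0):
    """One pastel '#rrggbb' color per class, with hues evenly spaced."""
    palette = []
    for i in range(n_classes):
        r, g, b = colorsys.hsv_to_rgb(i / n_classes, saturation, value)
        palette.append("#{:02x}{:02x}{:02x}".format(
            round(r * 255), round(g * 255), round(b * 255)))
    return palette


def pick_fill(values, palette, mixed="white"):
    """Single-class nodes get their class color; mixed nodes get `mixed`."""
    values = list(values)  # also accepts a NumPy row from clf.tree_.value
    top = max(values)
    if top == sum(values):
        return palette[values.index(top)]
    return mixed


palette = make_palette(3)
print(palette)                        # ['#ff9999', '#99ff99', '#9999ff']
print(pick_fill([0, 7, 0], palette))  # #99ff99
print(pick_fill([1, 2, 0], palette))  # white
```

In the loop, node.set_fillcolor(pick_fill(values, make_palette(len(clf.classes_)))) would replace the if/else; building the palette once outside the loop avoids recomputing it per node.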
