Color of the node of tree with graphviz using class_names
Question
Expanding on a prior question: Changing colors for decision tree plot created using export graphviz
How would I color the nodes of the tree based on the dominant class (species of iris) instead of a binary distinction? This should require combining iris.target_names, the strings describing the classes, and iris.target, the class labels.
import collections

import pydotplus
from sklearn import tree
from sklearn.datasets import load_iris

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)

# Two colors: one for the left child and one for the right child of every split
colors = ('brown', 'forestgreen')

# Map each parent node id to the ids of its children
edges = collections.defaultdict(list)
for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))

for edge in edges:
    edges[edge].sort()
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])

graph.write_png('tree.png')
Answer
The code from the example looks so familiar and is therefore easy to modify :)
For each node, the Graphviz label tells us how many samples of each class it contains, i.e. whether it is a mixed population or whether the tree has come to a decision. We can extract this information from the label and use it to derive a color:
values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
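The one-liner above relies on the label text containing `value = [...]`. Below is a minimal standalone sketch of the same extraction using a regular expression instead of chained split() calls. The helper name `parse_values` and the sample label are hypothetical, modeled on the kind of label export_graphviz writes with special_characters=True, and the sketch assumes the list sits on a single line of the label:

```python
import re

# Matches the bracketed list that follows "value = " in a node label.
_VALUE_RE = re.compile(r"value = \[([^\]]*)\]")


def parse_values(label):
    """Return the per-class values from a node label, or None if there are none."""
    match = _VALUE_RE.search(label or "")
    if match is None:
        return None
    # float() accepts both integer counts ("50") and decimals, and ignores
    # the space after each comma.
    return [float(v) for v in match.group(1).split(",")]


# Hypothetical label in the style produced with special_characters=True.
label = "<petal width (cm) &le; 0.8<br/>gini = 0.667<br/>samples = 150<br/>value = [50, 50, 50]>"
print(parse_values(label))  # → [50.0, 50.0, 50.0]
```

Returning None for a missing label plays the same role as the `if node.get_label():` check in the full script.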
Alternatively, you can map the Graphviz nodes back to the sklearn nodes. Each node's name is its node id in clf.tree_, so the class counts can be read directly:
values = clf.tree_.value[int(node.get_name())][0]
We only have 3 classes, so each one gets its own color (red, green, blue), and mixed populations get a mixed color according to their class distribution:
values = [int(255 * v / sum(values)) for v in values]
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
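The two lines above turn class counts into an `#rrggbb` string: each class's share of the node's samples is scaled to 0..255 and used as one RGB channel. Wrapped in a small function (the name `counts_to_hex` is hypothetical), the mapping can be checked in isolation. It only makes sense for exactly three classes, and an evenly mixed node comes out dark grey rather than white:

```python
def counts_to_hex(values):
    """Map three class counts to an RGB hex color (class 1 = red, 2 = green, 3 = blue)."""
    if len(values) != 3:
        raise ValueError("expected exactly three class counts")
    total = sum(values)
    # Each channel is the class's share of the node's samples, scaled to 0..255.
    channels = [int(255 * v / total) for v in values]
    return "#{:02x}{:02x}{:02x}".format(*channels)


print(counts_to_hex([50, 0, 0]))    # → #ff0000 (pure first class)
print(counts_to_hex([50, 50, 50]))  # → #555555 (evenly mixed)
```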
We can now see the separation nicely: the greener a node, the more samples of the 2nd class it holds, and the same goes for blue and the 3rd class.
import pydotplus
from sklearn import tree
from sklearn.datasets import load_iris

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf,
                                feature_names=iris.feature_names,
                                out_file=None,
                                filled=True,
                                rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()

for node in nodes:
    # Entries without a label (such as the 'node'/'edge' default-attribute statements) are skipped
    if node.get_label():
        # Class counts for this node, parsed from "value = [a, b, c]" in the label
        values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
        # Scale each class's share to 0..255 and use it as one RGB channel
        values = [int(255 * v / sum(values)) for v in values]
        color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
        node.set_fillcolor(color)

graph.write_png('colored_tree.png')
A more general solution for more than 3 classes, which colors only the leaves in which a single class is present and gives every mixed node the same default color:
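import numpy

# Continues from the script above (uses clf, nodes and graph).
# One color per class; the last entry is used for mixed nodes.
# Extend the tuple if you have more classes.
colors = ('lightblue', 'lightyellow', 'forestgreen', 'lightcoral', 'white')

for node in nodes:
    if node.get_name() not in ('node', 'edge'):
        # Node names are the sklearn node ids, so read the class counts directly
        values = clf.tree_.value[int(node.get_name())][0]
        # Color only nodes where a single class is present
        if max(values) == sum(values):
            node.set_fillcolor(colors[numpy.argmax(values)])
        # Mixed nodes get the default color
        else:
            node.set_fillcolor(colors[-1])

graph.write_png('colored_tree.png')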
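The snippet above gives all mixed nodes one default color. If you would rather keep the "mixed populations get mixed colors" idea from the three-class version for any number of classes, one option (an alternative, not part of the original answer) is to give each class name from iris.target_names a base RGB color and blend the base colors by the node's class proportions. The helper name `blend_fillcolor` and the palette are hypothetical choices for illustration; the RGB triples are the standard values of the named Graphviz/X11 colors:

```python
def blend_fillcolor(values, class_names, palette):
    """Blend per-class base colors, weighted by each class's share of the node.

    values: per-class counts (or fractions) for one node
    class_names: class names in the same order as values
    palette: dict mapping class name to an (r, g, b) tuple of 0..255 ints
    """
    total = sum(values)
    rgb = [0.0, 0.0, 0.0]
    for name, v in zip(class_names, values):
        weight = v / total
        for i, channel in enumerate(palette[name]):
            rgb[i] += weight * channel
    return "#{:02x}{:02x}{:02x}".format(*(round(c) for c in rgb))


# iris.target_names is ['setosa', 'versicolor', 'virginica'].
names = ["setosa", "versicolor", "virginica"]
palette = {"setosa": (173, 216, 230),      # lightblue
           "versicolor": (255, 255, 224),  # lightyellow
           "virginica": (34, 139, 34)}     # forestgreen
print(blend_fillcolor([50, 0, 0], names, palette))  # → #add8e6
```

Inside the loop above, the if/else could then be replaced by something like `node.set_fillcolor(blend_fillcolor(values, iris.target_names, palette))`, so pure leaves keep their class color and mixed nodes drift between the colors of the classes they contain.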