sklearn 绘图混淆矩阵与标签 [英] sklearn plot confusion matrix with labels

查看:21
本文介绍了sklearn 绘图混淆矩阵与标签的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想绘制一个混淆矩阵来可视化分类器的性能,但它只显示标签的数量,而不是标签本身:

I want to plot a confusion matrix to visualize the classifer's performance, but it shows only the numbers of the labels, not the labels themselves:

from sklearn.metrics import confusion_matrix
import pylab as pl
y_test=['business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business']

pred=array(['health', 'business', 'business', 'business', 'business',
       'business', 'health', 'health', 'business', 'business', 'business',
       'business', 'business', 'business', 'business', 'business',
       'health', 'health', 'business', 'health'], 
      dtype='|S8')

cm = confusion_matrix(y_test, pred)
pl.matshow(cm)
pl.title('Confusion matrix of the classifier')
pl.colorbar()
pl.show()

如何将标签(健康、业务等)添加到混淆矩阵中?

How can I add the labels (health, business..etc) to the confusion matrix?

推荐答案

正如 这个问题中所暗示的,您必须打开"低级艺术家 API,通过存储由您调用的 matplotlib 函数传递的图形和轴对象(下面的 figaxcax 变量).然后,您可以使用 set_xticklabels/set_yticklabels 替换默认的 x 轴和 y 轴刻度:

As hinted in this question, you have to "open" the lower-level artist API, by storing the figure and axis objects passed by the matplotlib functions you call (the fig, ax and cax variables below). You can then replace the default x- and y-axis ticks using set_xticklabels/set_yticklabels:

from sklearn.metrics import confusion_matrix

labels = ['business', 'health']
cm = confusion_matrix(y_test, pred, labels)
print(cm)
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(cm)
plt.title('Confusion matrix of the classifier')
fig.colorbar(cax)
ax.set_xticklabels([''] + labels)
ax.set_yticklabels([''] + labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

请注意,我将 labels 列表传递给了 confusion_matrix 函数,以确保它正确排序,与刻度匹配.

Note that I passed the labels list to the confusion_matrix function to make sure it's properly sorted, matching the ticks.

结果如下图:

这篇关于sklearn 绘图混淆矩阵与标签的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆