带标签的sklearn图混淆矩阵 [英] sklearn plot confusion matrix with labels
问题描述
我想绘制一个混淆矩阵以可视化分类器的性能,但它仅显示标签的数字,而不显示标签本身:
I want to plot a confusion matrix to visualize the classifer's performance, but it shows only the numbers of the labels, not the labels themselves:
from sklearn.metrics import confusion_matrix
import pylab as pl
y_test=['business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business']
pred=array(['health', 'business', 'business', 'business', 'business',
'business', 'health', 'health', 'business', 'business', 'business',
'business', 'business', 'business', 'business', 'business',
'health', 'health', 'business', 'health'],
dtype='|S8')
cm = confusion_matrix(y_test, pred)
pl.matshow(cm)
pl.title('Confusion matrix of the classifier')
pl.colorbar()
pl.show()
如何将标签(健康,业务等)添加到混乱矩阵中?
How can I add the labels (health, business..etc) to the confusion matrix?
推荐答案
如此问题中所述,则必须通过存储以下内容来打开" 下级艺术家API .您调用的matplotlib函数传递的图形和轴对象(下面的fig
,ax
和cax
变量).然后,您可以使用set_xticklabels
/set_yticklabels
替换默认的x轴和y轴刻度:
As hinted in this question, you have to "open" the lower-level artist API, by storing the figure and axis objects passed by the matplotlib functions you call (the fig
, ax
and cax
variables below). You can then replace the default x- and y-axis ticks using set_xticklabels
/set_yticklabels
:
from sklearn.metrics import confusion_matrix
labels = ['business', 'health']
cm = confusion_matrix(y_test, pred, labels)
print(cm)
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(cm)
plt.title('Confusion matrix of the classifier')
fig.colorbar(cax)
ax.set_xticklabels([''] + labels)
ax.set_yticklabels([''] + labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
请注意,我将labels
列表传递给了confusion_matrix
函数,以确保其正确排序,与刻度线匹配.
Note that I passed the labels
list to the confusion_matrix
function to make sure it's properly sorted, matching the ticks.
这将导致下图:
这篇关于带标签的sklearn图混淆矩阵的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!