ConfusionMatrixDisplay (Scikit-Learn) 绘图标签超出范围 [英] ConfusionMatrixDisplay (Scikit-Learn) plot labels out of range

查看:218
本文介绍了ConfusionMatrixDisplay (Scikit-Learn) 绘图标签超出范围的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

以下代码绘制了一个混淆矩阵:

from sklearn.metrics import ConfusionMatrixDisplay混淆矩阵=混淆矩阵(y_true,y_pred)target_names = [aaaaa"、bbbbbb"、ccccccc"、dddddddd"、eeeeeeeeee"、ffffffff"、gggggggggg"]disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=target_names)disp.plot(cmap=plt.cm.Blues,xticks_rotation=45)plt.savefig("conf.png")

这个情节有两个问题.

  1. y 轴标签被切断(真实标签).x 标签也被剪掉了.
  2. 名称对于 x 轴来说太长了.

为了解决第一个问题,我尝试使用 poof(bbox_inches='tight'),遗憾的是它不适用于 sklearn.在第二种情况下,我为

The following code plots a confusion matrix:

from sklearn.metrics import ConfusionMatrixDisplay

confusion_matrix = confusion_matrix(y_true, y_pred)
target_names = ["aaaaa", "bbbbbb", "ccccccc", "dddddddd", "eeeeeeeeee", "ffffffff", "ggggggggg"]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=target_names)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
plt.savefig("conf.png")

There are two problems with this plot.

  1. The y-axis label is cut off (True Label). The x label is cut off too.
  2. The names are to long for the x-axis.

To solve the first problem I tried to use poof(bbox_inches='tight') which is unfortunately not available for sklearn. In the second case I tried the following solution for 2. which lead to a completely distorted plot.

All in all I'm struggeling with both problems.

解决方案

I think the easiest way would be to switch into tight_layout and add pad_inches= something.

from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from numpy.random import default_rng

rand = default_rng()
y_true = rand.integers(low=0, high=7, size=500)
y_pred = rand.integers(low=0, high=7, size=500)


confusion_matrix = confusion_matrix(y_true, y_pred)
target_names = ["aaaaa", "bbbbbb", "ccccccc", "dddddddd", "eeeeeeeeee", "ffffffff", "ggggggggg"]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=target_names)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)

plt.tight_layout()
plt.savefig("conf.png", pad_inches=5)

Result:

这篇关于ConfusionMatrixDisplay (Scikit-Learn) 绘图标签超出范围的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆