Understanding a multi-label classifier using a confusion matrix

Question

I have a multi-label classification problem with 12 classes. I'm using TensorFlow's slim library to train the model, starting from models pretrained on ImageNet. Here are the percentages of images in which each class is present in the training and validation sets:

              Training (%)   Validation (%)
  class0         44.4             25
  class1         55.6             50
  class2         50               25
  class3         55.6             50
  class4         44.4             50
  class5         50               75
  class6         50               75
  class7         55.6             50
  class8         88.9             50
  class9         88.9             50
  class10        50               25
  class11        72.2             25

The problem is that the model did not converge and the area under the ROC curve (Az) on the validation set was poor, something like:

               Az
  class0      0.99
  class1      0.44
  class2      0.96
  class3      0.9
  class4      0.99
  class5      0.01
  class6      0.52
  class7      0.65
  class8      0.97
  class9      0.82
  class10     0.09
  class11     0.5
  Average     0.65

I had no clue why it works well for some classes and not for the others, so I decided to dig into the details to see what the neural network is learning. I know that a confusion matrix is only applicable to binary or multi-class classification. Thus, to be able to draw it, I had to convert the problem into pairs of multi-class classifications. Even though the model was trained using sigmoid to provide a prediction for each class, for every cell in the confusion matrix below I'm showing the average of the probabilities (obtained by applying the sigmoid function to the TensorFlow predictions) over the images where the class in the row of the matrix is present and the class in the column is not present. This was applied to the validation set images. This way I thought I could get more details about what the model is learning. I just circled the diagonal elements for display purposes.

My interpretation is:

  1. Classes 0 & 4 are detected as present when they are present and as not present when they are not. This means these classes are well detected.
  2. Classes 2, 6 & 7 are always detected as not present. This is not what I'm looking for.
  3. Classes 3, 8 & 9 are always detected as present. This is not what I'm looking for. The same applies to class 11.
  4. Class 5 is detected as present when it is not present and as not present when it is present. It is inversely detected.
  5. Classes 3 & 10: I don't think we can extract much information for these 2 classes.

My problem is the interpretation. I'm not sure where the problem is, and I'm not sure whether there is a bias in the dataset that produces such results. I'm also wondering whether there are metrics that can help in multi-label classification problems. Can you please share your interpretation of such a confusion matrix, and what/where to look next? Some suggestions for other metrics would be great.

Thanks.

EDIT: I converted the problem to multi-class classification, so for each pair of classes (e.g. 0, 1) I compute the probability(class 0, class 1), denoted p(0,1): I take the predictions for tool 1 on the images where tool 0 is present and tool 1 is not present, convert them to probabilities by applying the sigmoid function, and show the mean of those probabilities. For p(1,0), I do the same for tool 0, using the images where tool 1 is present and tool 0 is not present. For p(0,0), I use all the images where tool 0 is present. Considering p(0,4) in the image above, N/A means there are no images where tool 0 is present and tool 4 is not present.
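
The computation described above could be sketched roughly as follows. This is only an illustration, not the asker's actual code: the names `pairwise_mean_prob`, `logits` and `labels` are hypothetical, and the inputs are assumed to be plain nested lists with one row per image and one column per class.

```python
import math

def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def pairwise_mean_prob(logits, labels):
    """Build the matrix p described in the question.

    logits[n][c]: raw network output for image n and class c (assumed layout).
    labels[n][c]: 1 if class c is present in image n, else 0.
    p[i][j] (i != j): mean sigmoid score of class j over the images where
    class i is present and class j is not present.
    p[i][i]: mean sigmoid score of class i over the images where i is present.
    Cells with no matching images are left as None (the "N/A" cells).
    """
    num_classes = len(labels[0])
    p = [[None] * num_classes for _ in range(num_classes)]
    for i in range(num_classes):
        for j in range(num_classes):
            if i == j:
                rows = [n for n, y in enumerate(labels) if y[i] == 1]
            else:
                rows = [n for n, y in enumerate(labels) if y[i] == 1 and y[j] == 0]
            if rows:
                p[i][j] = sum(sigmoid(logits[n][j]) for n in rows) / len(rows)
    return p
```

Read this way, a well-detected class j would show a high value at p[j][j] and low values in the rest of column j.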

Here is the number of images in the 2 subsets:

  1. 169320 images for training
  2. 37440 images for validation

Here is the confusion matrix computed on the training set (computed the same way as the one for the validation set described previously), but this time the color code is the number of images used to compute each probability:

EDITED: For data augmentation, I apply a random translation, rotation and scaling to each input image of the network. Moreover, here is some information about the tools:

class 0 shape is completely different from the other objects.
class 1 strongly resembles class 4.
class 2 shape resembles classes 1 & 4, but it's always accompanied by an object different from the other objects in the scene. As a whole, it is different from the other objects.
class 3 shape is completely different from the other objects.
class 4 strongly resembles class 1.
class 5 has a shape in common with classes 6 & 7 (we can say that they are all from the same category of objects).
class 6 strongly resembles class 7.
class 7 strongly resembles class 6.
class 8 shape is completely different from the other objects.
class 9 strongly resembles class 10.
class 10 strongly resembles class 9.
class 11 shape is completely different from the other objects.

EDITED: Here is the output of the code proposed below for the training set:

Avg. num labels per image =  6.892700212615167
On average, images with label  0  also have  6.365296803652968  other labels.
On average, images with label  1  also have  6.601033718926901  other labels.
On average, images with label  2  also have  6.758548914659531  other labels.
On average, images with label  3  also have  6.131520940484937  other labels.
On average, images with label  4  also have  6.219187208527648  other labels.
On average, images with label  5  also have  6.536933407946279  other labels.
On average, images with label  6  also have  6.533908387864367  other labels.
On average, images with label  7  also have  6.485973817793214  other labels.
On average, images with label  8  also have  6.1241642788920725  other labels.
On average, images with label  9  also have  5.94092288040875  other labels.
On average, images with label  10  also have  6.983303518187239  other labels.
On average, images with label  11  also have  6.1974066621953945  other labels.

For the validation set:

Avg. num labels per image =  6.001282051282051
On average, images with label  0  also have  6.0  other labels.
On average, images with label  1  also have  3.987080103359173  other labels.
On average, images with label  2  also have  6.0  other labels.
On average, images with label  3  also have  5.507731958762887  other labels.
On average, images with label  4  also have  5.506459948320414  other labels.
On average, images with label  5  also have  5.00169779286927  other labels.
On average, images with label  6  also have  5.6729452054794525  other labels.
On average, images with label  7  also have  6.0  other labels.
On average, images with label  8  also have  6.0  other labels.
On average, images with label  9  also have  5.506459948320414  other labels.
On average, images with label  10  also have  3.0  other labels.
On average, images with label  11  also have  4.666095890410959  other labels.

Comments: I think it is not only related to the difference between the distributions, because if the model were able to generalize well on class 10 (meaning the object was recognized properly during training, like class 0), the accuracy on the validation set would be good enough. I mean that the problem lies in the training set per se and in how it was built, more than in the difference between both distributions. It can be: the frequency of presence of the class, or objects that strongly resemble each other (as in the case of class 10, which strongly resembles class 9), or bias inside the dataset, or thin objects (representing maybe 1 or 2% of the pixels in the input image, like class 2). I'm not saying that the problem is one of them, but I just wanted to point out that I think it's more than the difference between both distributions.

Recommended answer

Output Calibration

One thing that I think is important to realise first is that the outputs of a neural network may be poorly calibrated. What I mean by that is, the outputs it gives to different instances may result in a good ranking (images with label L tend to have higher scores for that label than images without label L), but these scores cannot always reliably be interpreted as probabilities (it may give very high scores, like 0.9, to instances without the label, and just give even higher scores, like 0.99, to instances with the label). I suppose whether or not this happens depends, among other things, on your chosen loss function.

For more info on this, see for example: https://arxiv.org/abs/1706.04599
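
To make the distinction between ranking and calibration concrete, here is a small sketch with made-up scores (the numbers and the helper names `auc` and `accuracy_at` are purely illustrative, not taken from the question). The scores rank the images perfectly, yet a threshold of 0.5 labels every image as positive.

```python
def auc(scores, labels):
    """Probability that a random positive outscores a random negative (ties count half)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def accuracy_at(scores, labels, threshold):
    """Accuracy when predicting 'present' for every score >= threshold."""
    preds = [1 if s >= threshold else 0 for s in scores]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

# Hypothetical, poorly calibrated scores: negatives near 0.9, positives near 0.99
scores = [0.90, 0.91, 0.92, 0.99, 0.99, 0.98]
labels = [0, 0, 0, 1, 1, 1]

print(auc(scores, labels))                # 1.0: the ranking is perfect
print(accuracy_at(scores, labels, 0.5))   # 0.5: everything is predicted as present
print(accuracy_at(scores, labels, 0.95))  # 1.0: a higher threshold separates the groups
```

A matrix of mean scores would show this label as "always present", even though the label is perfectly separable with a different threshold.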

Class 0: AUC (area under curve) = 0.99. That's a very good score. Column 0 in your confusion matrix also looks fine, so nothing wrong here.

Class 1: AUC = 0.44. That's quite terrible. It is lower than 0.5, which, if I'm not mistaken, pretty much means you're better off deliberately doing the opposite of what your network predicts for this label.
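
The reason an AUC below 0.5 suggests "doing the opposite" can be written out using the standard interpretation of AUC as the probability that a randomly chosen positive instance gets a higher score than a randomly chosen negative one. The notation here is my own: s^+ and s^- denote the scores of a random positive and a random negative instance.

```latex
\mathrm{AUC}(s) = P(s^{+} > s^{-}) + \tfrac{1}{2}\,P(s^{+} = s^{-}),
\qquad
\mathrm{AUC}(-s) = P(s^{+} < s^{-}) + \tfrac{1}{2}\,P(s^{+} = s^{-}) = 1 - \mathrm{AUC}(s)
```

So on this validation set, negating the scores for class 1 would give an AUC of 1 - 0.44 = 0.56. This is only a statement about this particular validation set; it does not imply the inverted predictions would generalize.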

Looking at column 1 in your confusion matrix, it has pretty much the same scores everywhere. To me, this indicates that the network did not manage to learn a lot about this class, and pretty much just "guesses" according to the percentage of images that contained this label in the training set (55.6%). Since this percentage dropped to 50% in the validation set, this strategy indeed means that it'll do slightly worse than random. Row 1 still has the highest number of all rows in this column though, so it appears to have learned at least a tiny little bit, but not much.

Class 2: AUC = 0.96. That's very good.

Your interpretation for this class was that it's always predicted as not being present, based on the light shading of the entire column. I don't think that interpretation is correct though. See how it has a score > 0 on the diagonal, and just 0s everywhere else in the column. It may have a relatively low score in that row, but it's easily separable from the other rows in the same column. You'll probably just have to set your threshold for choosing whether or not that label is present relatively low. I suspect this is due to the calibration issue mentioned above.

This is also why the AUC is in fact very good; it is possible to select a threshold such that most instances with scores above the threshold correctly have the label, and most instances below it correctly do not. That threshold may not be 0.5 though, which is the threshold you would expect if you assumed good calibration. Plotting the ROC curve for this specific label may help you decide exactly where the threshold should be.

Class 3: AUC = 0.9, quite good.

You interpreted it as always being detected as present, and the confusion matrix does indeed have a lot of high numbers in the column, but the AUC is good and the cell on the diagonal does have a sufficiently high value that it may be easily separable from the others. I suspect this is a similar case to class 2 (just flipped around: high predictions everywhere, and therefore a high threshold required for correct decisions).

If you want to be able to tell for sure whether a well-selected threshold can indeed correctly split most "positives" (instances with class 3) from most "negatives" (instances without class 3), you'll want to sort all instances according to their predicted score for label 3, then go through the entire list and, between every pair of consecutive entries, compute the accuracy over the validation set that you would get if you placed your threshold right there. Then select the best threshold.
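
The threshold search described above could look roughly like this. It is a minimal sketch under stated assumptions: `best_threshold` is a hypothetical helper, `scores` holds the sigmoid outputs for one label, `labels` holds 0/1 ground truth for that label, and an image counts as "present" when its score is strictly above the threshold.

```python
def best_threshold(scores, labels):
    """Sweep a threshold through the sorted scores and keep the most accurate one.

    Returns (threshold, accuracy). Candidate thresholds are placed below all
    scores, halfway between consecutive distinct scores, and above all scores.
    """
    pairs = sorted(zip(scores, labels))
    n = len(pairs)
    # Threshold below every score: all images predicted present,
    # so exactly the positives are classified correctly.
    correct = sum(y for _, y in pairs)
    best_acc = correct / n
    best_t = pairs[0][0] - 1.0
    for k, (s, y) in enumerate(pairs):
        # Moving the threshold above this image flips its prediction to "absent".
        correct += 1 if y == 0 else -1
        next_s = pairs[k + 1][0] if k + 1 < n else s + 1.0
        if next_s == s:
            continue  # no threshold fits between tied scores
        acc = correct / n
        if acc > best_acc:
            best_acc, best_t = acc, (s + next_s) / 2.0
    return best_t, best_acc
```

For a label like class 2, where all scores are low, such a sweep would be expected to return a threshold well below 0.5; for a label like class 3, one well above it. Note that a threshold tuned on the validation set gives an optimistic accuracy estimate for that same set.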

Class 4: same as class 0.

Class 5: AUC = 0.01, obviously terrible. I also agree with your interpretation of the confusion matrix. It's difficult to tell for sure why it's performing so poorly here. Maybe it is a difficult kind of object to recognize? There's probably also some overfitting going on (0 false positives in the training data, judging from the column in your second matrix, though there are also other classes where this happens).

It probably also doesn't help that the proportion of label 5 images has increased going from training to validation data. This means that it was less important for the network to perform well on this label during training than it is during validation.

Class 6: AUC = 0.52, only slightly better than random.

Judging by column 6 in the first matrix, this actually could have been a similar case to class 2. If we also take the AUC into account though, it looks like it didn't learn to rank instances very well either. Similar to class 5, just not as bad. Also, again, the training and validation distributions are quite different.

Class 7: AUC = 0.65, rather average. Obviously not as good as class 2 for example, but also not as bad as you might conclude from the matrix alone.

Class 8: AUC = 0.97, very good, similar to class 3.

Class 9: AUC = 0.82, not as good, but still good. The column in the matrix has so many dark cells, and the numbers are so close, that the AUC is surprisingly good in my opinion. It was present in almost every image in the training data, so it's no surprise that it often gets predicted as being present. Maybe some of those very dark cells are based only on a low absolute number of images? This would be interesting to figure out.

Class 10: AUC = 0.09, terrible. A 0 on the diagonal is quite concerning (is your data labelled correctly?). It seems to get confused with classes 3 and 9 very often according to row 10 of the first matrix (do cotton and primary_incision_knives look a lot like secondary_incision_knives?). Maybe there is also some overfitting to the training data.

Class 11: AUC = 0.5, no better than random. The poor performance (and apparently excessively high scores in the matrix) is likely because this label was present in the majority of training images, but only a minority of validation images.

To gain more insight into your data, I'd start out by plotting heatmaps of how often every pair of classes co-occurs (one for training and one for validation data). Cell (i, j) would be colored according to the ratio of images that contain both labels i and j. This would be a symmetric plot, with the cells on the diagonal colored according to those first lists of percentages in your question. Compare the two heatmaps, see where they are very different, and see if that can help to explain your model's performance.
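
The numbers behind such heatmaps could be computed along these lines. This is a sketch with hypothetical names (`cooccurrence_ratios`, `largest_shifts`), assuming each image's labels are available as a set of label indices.

```python
def cooccurrence_ratios(label_sets, num_labels):
    """ratios[i][j] = fraction of images that contain both label i and label j.

    The diagonal ratios[i][i] is simply the fraction of images containing i.
    label_sets: one set of label indices per image (assumed input format).
    """
    counts = [[0] * num_labels for _ in range(num_labels)]
    for labels in label_sets:
        for i in labels:
            for j in labels:
                counts[i][j] += 1
    n = len(label_sets)
    return [[c / n for c in row] for row in counts]

def largest_shifts(train_ratios, val_ratios, top=5):
    """List the (difference, i, j) cells that differ most between the two matrices."""
    size = len(train_ratios)
    shifts = [
        (abs(train_ratios[i][j] - val_ratios[i][j]), i, j)
        for i in range(size)
        for j in range(i, size)  # the matrix is symmetric, so the upper triangle suffices
    ]
    return sorted(shifts, reverse=True)[:top]
```

Each returned matrix can be handed to any heatmap plotting routine. `largest_shifts` is just a textual shortcut for spotting the cells where training and validation disagree most.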

Additionally, it may be useful to know (for both datasets) how many different labels each image has on average, and, for every individual label, how many other labels it shares an image with on average. For example, I suspect images with label 10 have relatively few other labels in the training data. This may dissuade the network from predicting label 10 if it recognises other things, and cause poor performance if label 10 does suddenly share images with other objects more regularly in the validation data. Since pseudocode may get the point across more easily than words, it could be interesting to print something like the following:

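# Do all of the following once for training data, AND once for validation data.
# Assumption: label_sets holds one set of label indices per image.
def print_label_stats(label_sets, num_labels):
    # Average number of labels per image
    tot_num_labels = sum(len(labels) for labels in label_sets)
    avg_labels_per_image = tot_num_labels / float(len(label_sets))
    print("Avg. num labels per image = ", avg_labels_per_image)

    for label in range(num_labels):
        # All images that contain this label
        with_label = [labels for labels in label_sets if label in labels]
        if not with_label:
            continue  # label never occurs; avoid dividing by zero
        # Count the other labels that share an image with this label
        tot_shared_labels = sum(len(labels) - 1 for labels in with_label)
        avg_shared_labels = tot_shared_labels / float(len(with_label))
        print("On average, images with label ", label, " also have ", avg_shared_labels, " other labels.")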

For just a single dataset this doesn't provide much useful information, but if you do it for both the training and validation sets, very different numbers tell you that their distributions are quite different.

Finally, I am a bit concerned by how some columns in your first matrix have exactly the same mean prediction appearing over many different rows. I am not quite sure what could cause this, but it may be useful to investigate.

If you haven't already, I'd recommend looking into data augmentation for your training data. Since you're working with images, you could try adding rotated versions of existing images to your data.

For your multi-label case specifically, where the goal is to detect different types of objects, it may also be interesting to try simply concatenating a bunch of different images (e.g. two or four images) together. You could then scale the result down to the original image size, and assign the union of the original sets of labels as its labels. You'd get funny discontinuities along the edges where you merge images; I don't know if that'd be harmful. Maybe it wouldn't be for your case of multi-object detection, so it's worth a try in my opinion.
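
As a rough illustration of the concatenation idea, here is a toy version operating on images stored as nested lists of pixels. The helper `mosaic_2x2` is hypothetical, and the downscaling simply keeps every second pixel; a real pipeline would use a proper image-resizing routine on array data.

```python
def mosaic_2x2(images, label_sets):
    """Tile four equally sized images into a 2x2 grid and shrink it back.

    images: four images, each a list of rows, each row a list of pixel values.
    label_sets: the four corresponding sets of label indices.
    Returns (image, labels): the mosaic, downscaled to the original size by
    keeping every second row and column, and the union of the four label sets.
    """
    a, b, c, d = images
    # Place a and b side by side on top, c and d side by side below
    top = [row_a + row_b for row_a, row_b in zip(a, b)]
    bottom = [row_c + row_d for row_c, row_d in zip(c, d)]
    full = top + bottom
    # Naive 2x downscale (nearest neighbour); assumes even height and width
    small = [row[::2] for row in full[::2]]
    labels = set().union(*label_sets)
    return small, labels
```

Whether halving the resolution keeps thin objects (such as class 2, at 1 or 2% of the pixels) recognizable is something that would need to be checked on real images.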
