Pytorch 中类别不平衡的多标签分类 [英] Multilabel classification with class imbalance in Pytorch

查看:25
本文介绍了Pytorch 中类别不平衡的多标签分类的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个多标签分类问题,我正试图用 Pytorch 中的 CNN 解决这个问题.我有 80,000 个训练示例和 7900 个类;每个示例可以同时属于多个类,每个示例的平均类数为 130.

I have a multilabel classification problem, which I am trying to solve with CNNs in Pytorch. I have 80,000 training examples and 7900 classes; every example can belong to multiple classes at the same time, mean number of classes per example is 130.

问题是我的数据集非常不平衡.对于某些课程,我只有大约 900 个示例,大约为 1%.对于过度代表"的类,我有大约 12000 个示例(15%).当我训练模型时,我使用来自 pytorch 的 BCEWithLogitsLoss正权重参数.我按照文档中描述的相同方式计算权重:负例数除以正例数.

The problem is that my dataset is very imbalance. For some classes, I have only ~900 examples, which is around 1%. For "overrepresented" classes I have ~12000 examples (15%). When I train the model I use BCEWithLogitsLoss from pytorch with a positive weights parameter. I calculate the weights the same way as described in the documentation: the number of negative examples divided by the number of positives.

因此,我的模型几乎高估了每个班级……我得到的预测几乎是真实标签的两倍.而我的 AUPRC 仅为 0.18.尽管它比完全不加权要好得多,因为在这种情况下,模型将所有内容预测为零.

As a result, my model overestimates almost every class… Mor minor and major classes I get almost twice as many predictions as true labels. And my AUPRC is just 0.18. Even though it’s much better than no weighting at all, since in this case the model predicts everything as zero.

所以我的问题是,如何提高性能?还有什么我可以做的吗?我尝试了不同的批量抽样技术(对少数类进行过采样),但它们似乎不起作用.

So my question is, how do I improve the performance? Is there anything else I can do? I tried different batch sampling techniques (to oversample minority class), but they don’t seem to work.

推荐答案

我会建议其中一种策略


中引入了一种通过调整损失函数来处理不平衡训练数据的非常有趣的方法Tsung-Yi Lin、Priya Goyal、Ross Girshick、Kaiming He 和 Piotr Dollar 用于密集物体检测的焦点损失(ICCV 2017).
他们建议修改二元交叉熵损失,以减少容易分类的样本的损失和梯度,同时集中精力"模型出现严重错误的示例.

A very interesting approach for dealing with un-balanced training data through tweaking of the loss function was introduced in
Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He and Piotr Dollar Focal Loss for Dense Object Detection (ICCV 2017).
They propose to modify the binary cross entropy loss in a way that decrease the loss and gradient of easily classified examples while "focusing the effort" on examples where the model makes gross errors.

另一种流行的方法是进行硬负挖掘";也就是说,仅针对部分训练示例传播梯度——硬"训练示例.那些.
见,例如:
Abhinav Shrivastava、Abhinav Gupta 和 Ross Girshick 在线训练基于区域的对象检测器困难示例挖掘(CVPR 2016)

Another popular approach is to do "hard negative mining"; that is, propagate gradients only for part of the training examples - the "hard" ones.
see, e.g.:
Abhinav Shrivastava, Abhinav Gupta and Ross Girshick Training Region-based Object Detectors with Online Hard Example Mining (CVPR 2016)

这篇关于Pytorch 中类别不平衡的多标签分类的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆