Multi-label classification in PyTorch


Problem description

I have a multi-label classification problem: 11 classes and around 4k examples, where each example can have from 1 to 4-5 labels. At the moment, I'm training a separate classifier for each class with log_loss. As you can expect, training 11 classifiers takes quite some time, so I would like to try another approach and train only one classifier. The idea is that the last layer of this classifier would have 11 nodes and would output a real number per class, which would be converted to a probability by a sigmoid. The loss I want to optimize is the mean of the log_loss over all classes.

Unfortunately, I'm something of a noob with PyTorch, and even after reading the source code of the losses, I can't figure out whether one of the existing losses does exactly what I want, or whether I should create a new loss; if so, I don't really know how to do it.

To be very specific: for each element of the batch I want to supply one vector of size 11 that contains a real number for each label (the closer to infinity, the closer this class is predicted to be 1), and one vector of size 11 that contains a 1 at every true label, then compute the mean log_loss over all 11 labels and optimize my classifier based on that loss.
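
The mean log_loss described above can be written out directly. A minimal sketch (all shapes and values here are made up for illustration), checked against PyTorch's built-in binary cross-entropy:

```python
import torch
import torch.nn.functional as F

batch_size, num_classes = 3, 11  # hypothetical sizes matching the question

logits = torch.randn(batch_size, num_classes)            # real number per label
targets = torch.randint(0, 2, (batch_size, num_classes)).float()  # 1 at every true label

probs = torch.sigmoid(logits)
# Mean log-loss over all labels and all batch elements, written out explicitly.
manual_loss = -(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs)).mean()

# It coincides with PyTorch's binary cross-entropy on the same inputs.
assert torch.allclose(manual_loss, F.binary_cross_entropy(probs, targets))
```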

Any help would be greatly appreciated :)

Answer

You are looking for torch.nn.BCELoss. Here's example code:

import torch

batch_size = 2
num_classes = 11

loss_fn = torch.nn.BCELoss()

outputs_before_sigmoid = torch.randn(batch_size, num_classes)
sigmoid_outputs = torch.sigmoid(outputs_before_sigmoid)
target_classes = torch.randint(0, 2, (batch_size, num_classes)).float()  # multi-hot targets; BCELoss requires float, not the Long tensor randint returns

loss = loss_fn(sigmoid_outputs, target_classes)

# alternatively, use BCE with logits on the raw outputs (numerically more stable).
loss_fn_2 = torch.nn.BCEWithLogitsLoss()
loss2 = loss_fn_2(outputs_before_sigmoid, target_classes)
assert torch.allclose(loss, loss2)  # compare with a tolerance; exact float equality may fail
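
Putting it together, a minimal end-to-end sketch: a single linear classifier with 11 output nodes trained with BCEWithLogitsLoss. The feature size, learning rate, and data below are made up for illustration.

```python
import torch

num_features, num_classes = 20, 11                 # hypothetical input size, 11 labels
model = torch.nn.Linear(num_features, num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.BCEWithLogitsLoss()             # applies sigmoid internally

x = torch.randn(64, num_features)                  # fake batch of features
y = torch.randint(0, 2, (64, num_classes)).float() # multi-hot labels

for _ in range(5):                                 # a few gradient steps
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)                    # operates on raw logits
    loss.backward()
    optimizer.step()

# At inference time, per-class probabilities come from a sigmoid over the logits.
probs = torch.sigmoid(model(x))
```

Because the sigmoid is folded into the loss, the model itself outputs raw logits; the sigmoid is only applied explicitly at prediction time.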
