如何有效地计算 PyTorch 中的混淆矩阵? [英] How do I calculate the confusion matrix in PyTorch efficiently?

查看:24
本文介绍了如何有效地计算 PyTorch 中的混淆矩阵?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个包含我的预测的张量和一个包含我的二元分类问题的实际标签的张量.如何有效地计算混淆矩阵?

I have a tensor that contains my predictions and a tensor that contains the actual labels for my binary classification problem. How can I calculate the confusion matrix efficiently?

推荐答案

在我使用 for 循环的第一个版本被证明效率低下之后,这是我目前想到的最快的解决方案,对于两个等维张量 预测真相:

After my first version using a for-loop has proven inefficient, this is the fastest solution I came up with so far, for two equal-dimensional tensors prediction and truth:

def confusion(prediction, truth):
    confusion_vector = prediction / truth

    true_positives = torch.sum(confusion_vector == 1).item()
    false_positives = torch.sum(confusion_vector == float('inf')).item()
    true_negatives = torch.sum(torch.isnan(confusion_vector)).item()
    false_negatives = torch.sum(confusion_vector == 0).item()

    return true_positives, false_positives, true_negatives, false_negatives

https://gist.github.com/the-的评论版本和测试用例低音/cae9f3976866776dea17a5049013258d

这篇关于如何有效地计算 PyTorch 中的混淆矩阵?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆