Using the predict_proba() function of RandomForestClassifier in the safe and right way


Question

I'm using scikit-learn to apply machine learning algorithms to my datasets. Sometimes I need the probabilities of the labels/classes instead of the labels/classes themselves. For example, rather than labelling an email Spam/Not Spam, I want a result such as: a 0.78 probability that a given email is spam.

For this purpose, I'm using predict_proba() with RandomForestClassifier as follows:

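from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

# X, y: training features and labels; Xtest: features of the emails to score.
# min_samples_split must be at least 2 in current scikit-learn releases.
clf = RandomForestClassifier(n_estimators=10, max_depth=None,
    min_samples_split=2, random_state=0)
scores = cross_val_score(clf, X, y)
print(scores.mean())

# Fit on the full training set, then get one probability per class for each row.
classifier = clf.fit(X, y)
predictions = classifier.predict_proba(Xtest)
print(predictions)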

I got these results:

 [ 0.4  0.6]
 [ 0.1  0.9]
 [ 0.2  0.8]
 [ 0.7  0.3]
 [ 0.3  0.7]
 [ 0.3  0.7]
 [ 0.7  0.3]
 [ 0.4  0.6]

Here the second column is for the class Spam. However, I have two main concerns about these results. First, do the results represent the probabilities of the labels without being affected by the size of my data? Second, the results show only one decimal digit, which is not precise enough in cases where a probability of 0.701 is very different from 0.708. Is there any way to get, for example, the next 5 digits?

Recommended answer

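  1. I get more than one digit in my results. Are you sure it is not due to your dataset? (For example, a very small dataset would yield simple decision trees and therefore "simple" probabilities.) Otherwise it may only be the display that shows one digit; try printing predictions[0,0].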
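One further possible cause of the single digit, offered as an illustration rather than as part of the original answer: a forest's predict_proba() averages the class probabilities of its trees, and fully grown trees (max_depth=None with small leaves) usually end in pure leaves, so each tree outputs 0 or 1. With n_estimators=10 the averages then land on multiples of 0.1. The sketch below uses a synthetic dataset from make_classification as a stand-in for the email data, which is an assumption; the variable names and the choice of 200 trees with min_samples_leaf=5 are also illustrative.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in for the spam data (assumption: the real dataset is not shown).
X, y = make_classification(n_samples=500, n_features=20, random_state=0)
X_train, y_train, X_test = X[:400], y[:400], X[400:]

# Ten fully grown trees: each tree votes 0 or 1, so the forest average
# is a multiple of 1 / n_estimators = 0.1.
coarse = RandomForestClassifier(n_estimators=10, random_state=0).fit(X_train, y_train)
proba_coarse = coarse.predict_proba(X_test)

# More trees and larger leaves tend to give finer-grained probabilities.
fine = RandomForestClassifier(
    n_estimators=200, min_samples_leaf=5, random_state=0
).fit(X_train, y_train)
proba_fine = fine.predict_proba(X_test)

# Columns follow clf.classes_; look up the column for label 1 explicitly.
spam_col = list(coarse.classes_).index(1)

# Control how many digits NumPy prints; the stored values keep full precision.
np.set_printoptions(precision=5, suppress=True)
print(proba_coarse[:3])
print(proba_fine[:3])
print(repr(proba_fine[0, spam_col]))
```

If the coarse model is the one that matches your setup, increasing n_estimators is probably the simplest way to get more distinct probability values.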

  2. I am not sure I understand what you mean by "the probabilities aren't affected by the size of my data". If your concern is that you don't want to predict, e.g., too many emails as spam, the usual approach is to pick a threshold t and predict 1 if proba(label==1) > t. This way you can use the threshold to balance your predictions, for example to limit the overall proportion of emails predicted as spam. If you want to evaluate your model globally, the usual metric is the area under the curve (AUC) of the receiver operating characteristic (ROC) curve (see the Wikipedia article on the receiver operating characteristic). Basically, the ROC curve describes how your predictions behave as the threshold t varies.
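A minimal sketch of the thresholding and ROC AUC steps, assuming a synthetic dataset in place of the email data and an arbitrary threshold t = 0.7 (both are illustrative choices, not values from the question):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the spam data (assumption: the real dataset is not shown).
X, y = make_classification(n_samples=1000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

clf = RandomForestClassifier(n_estimators=200, random_state=0).fit(X_train, y_train)

# classes_ is sorted, so for labels {0, 1} column 1 holds proba(label == 1).
p_spam = clf.predict_proba(X_test)[:, 1]

# Predict spam only when the probability exceeds the chosen threshold t.
t = 0.7
y_pred = (p_spam > t).astype(int)
print("fraction predicted as spam:", y_pred.mean())

# ROC AUC summarises ranking quality over all thresholds at once.
auc = roc_auc_score(y_test, p_spam)
fpr, tpr, thresholds = roc_curve(y_test, p_spam)
print("ROC AUC:", auc)
```

Raising t makes the classifier flag fewer emails as spam; the (fpr, tpr) pairs returned by roc_curve show the trade-off at each candidate threshold.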

Hope that helps!
