插入符号中训练数据的 ROC 曲线 [英] ROC curve from training data in caret

查看:54
本文介绍了插入符号中训练数据的 ROC 曲线的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

使用R包caret,如何根据train()函数的交叉验证结果生成ROC曲线?

Using the R package caret, how can I generate a ROC curve based on the cross-validation results of the train() function?

比如说,我执行以下操作:

Say, I do the following:

data(Sonar)
ctrl <- trainControl(method="cv", 
  summaryFunction=twoClassSummary, 
  classProbs=T)
rfFit <- train(Class ~ ., data=Sonar, 
  method="rf", preProc=c("center", "scale"), 
  trControl=ctrl)

训练函数遍历一系列 mtry 参数并计算 ROC AUC.我想查看相关的 ROC 曲线——我该怎么做?

The training function goes over a range of mtry parameter and calculates the ROC AUC. I would like to see the associated ROC curve -- how do I do that?

注意:如果采样所用的方法是LOOCV,那么rfFit会在rfFit$pred槽中包含一个非空的数据帧,这似乎正是我需要的.但是,对于cv"方法(k 折验证)而不是 LOO,我需要它.

Note: if the method used for sampling is LOOCV, then rfFit will contain a non-null data frame in the rfFit$pred slot, which seems to be exactly what I need. However, I need that for the "cv" method (k-fold validation) rather than LOO.

另外:不,以前版本的 caret 中包含的 roc 函数不是答案——这是一个低级函数,如果你不这样做,你就不能使用它具有每个交叉验证样本的预测概率.

Also: no, roc function that used to be included in former versions of caret is not an answer -- this is a low level function, you can't use it if you don't have the prediction probabilities for each cross-validated sample.

推荐答案

ctrl 中仅缺少 savePredictions = TRUE 参数(这也适用于其他重采样方法):

There is just the savePredictions = TRUE argument missing from ctrl (this also works for other resampling methods):

library(caret)
library(mlbench)
data(Sonar)
ctrl <- trainControl(method="cv", 
                     summaryFunction=twoClassSummary, 
                     classProbs=T,
                     savePredictions = T)
rfFit <- train(Class ~ ., data=Sonar, 
               method="rf", preProc=c("center", "scale"), 
               trControl=ctrl)
library(pROC)
# Select a parameter setting
selectedIndices <- rfFit$pred$mtry == 2
# Plot:
plot.roc(rfFit$pred$obs[selectedIndices],
         rfFit$pred$M[selectedIndices])

也许我遗漏了一些东西,但一个小问题是 train 估计的 AUC 值总是与 plot.rocpROC::auc(绝对差 <0.005),尽管 twoClassSummary 使用 pROC::auc 来估计 AUC.我认为这是因为 train 的 ROC 是使用单独 CV 集的 AUC 的平均值,这里我们同时计算所有重采样的 AUC 以获得整体 AUC.

Maybe I am missing something, but a small concern is that train always estimates slightly different AUC values than plot.roc and pROC::auc (absolute difference < 0.005), although twoClassSummary uses pROC::auc to estimate the AUC. I assume this occurs because the ROC from train is the average of the AUC using the separate CV-Sets and here we are calculating the AUC over all resamples simultaneously to obtain the overall AUC.

更新 由于这引起了一些关注,这里有一个使用 plotROC::geom_roc() 用于 ggplot2 的解决方案:

Update Since this is getting a bit of attention, here's a solution using plotROC::geom_roc() for ggplot2:

library(ggplot2)
library(plotROC)
ggplot(rfFit$pred[selectedIndices, ], 
       aes(m = M, d = factor(obs, levels = c("R", "M")))) + 
    geom_roc(hjust = -0.4, vjust = 1.5) + coord_equal()

这篇关于插入符号中训练数据的 ROC 曲线的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆