prSummary in r caret package for imbalanced data


Problem description

I have an imbalanced data set, and I want to do stratified cross-validation and use the precision-recall AUC as my evaluation metric.

I use prSummary in the r caret package with stratified fold indices, and I run into an error when the performance is computed.

The following is a reproducible example. I found that only ten samples are used to compute the p-r AUC, and because of the imbalance they all belong to a single class, so the p-r AUC cannot be computed. (I discovered that only ten samples are used because I modified prSummary to force the function to print out the data it receives.)

library(randomForest)
library(mlbench)
library(caret)

# Load Dataset
data(Sonar)
dataset <- Sonar
x <- dataset[,1:60]
y <- dataset[,61]
# make this data very imbalanced
y[4:length(y)] <- "M"
y <- as.factor(y)
dataset$Class <- y

# create index and indexOut 
seed <- 1
set.seed(seed)
folds <- 2
idxAll <- 1:nrow(x)
cvIndex <- createFolds(factor(y), folds, returnTrain = T)
cvIndexOut <- lapply(1:length(cvIndex), function(i){
    idxAll[-cvIndex[[i]]]
})
names(cvIndexOut) <- names(cvIndex)

# set the index, indexOut and prSummaryCorrect
control <- trainControl(index = cvIndex, indexOut = cvIndexOut, 
                            method="cv", summaryFunction = prSummary, classProbs = T)
metric <- "AUC"
set.seed(seed)
mtry <- sqrt(ncol(x))
tunegrid <- expand.grid(.mtry=mtry)
rf_default <- train(Class~., data=dataset, method="rf", metric=metric, tuneGrid=tunegrid, trControl=control)

The error message is as follows:

Error in ROCR::prediction(y_pred, y_true) : 
Number of classes is not equal to 2.
ROCR currently supports only evaluation of binary classification tasks. 
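
The error message points at ROCR::prediction, which MLmetrics::PRAUC uses internally. It can be reproduced outside caret by handing ROCR::prediction a label vector that contains only one class; the short sketch below uses made-up toy vectors for illustration.

library(ROCR)

# toy vectors: ten predicted probabilities, but every true label is the
# same class -- the situation that arises when all rows passed to the
# summary function belong to the majority class
y_pred <- runif(10)
y_true <- rep(1, 10)
ROCR::prediction(y_pred, y_true)
# Error in ROCR::prediction(y_pred, y_true) :
#   Number of classes is not equal to 2.
#   ROCR currently supports only evaluation of binary classification tasks.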

Answer

I think I found something weird...

Even though I specified the cross-validation indices, the summary function (whether prSummary or any other summary function) still appears to receive a small, seemingly random subset of about ten samples when computing performance.
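
One way to see what the summary function actually receives is to plug a minimal diagnostic summary function into trainControl. The sketch below is illustrative (debugSummary and the "Dummy" metric name are not part of caret) and reuses cvIndex, cvIndexOut, tunegrid and seed from the question.

# illustrative diagnostic: report what caret hands to summaryFunction
debugSummary <- function(data, lev = NULL, model = NULL) {
  cat("rows received:", nrow(data),
      "| classes present:", paste(unique(as.character(data$obs)), collapse = ", "), "\n")
  c(Dummy = 0)  # return a named numeric vector so train() accepts it
}

control_dbg <- trainControl(index = cvIndex, indexOut = cvIndexOut,
                            method = "cv", summaryFunction = debugSummary,
                            classProbs = TRUE)
set.seed(seed)
train(Class ~ ., data = dataset, method = "rf", metric = "Dummy",
      maximize = TRUE, tuneGrid = tunegrid, trControl = control_dbg)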

What I did was define a summary function with tryCatch so that this error does not abort the run.

prSummaryCorrect <- function(data, lev = NULL, model = NULL) {
  # debugging output: show exactly what caret passes to the summary function
  print(dim(data))
  print(data)

  library(MLmetrics)

  if (length(levels(data$obs)) > 2)
    stop(paste("Your outcome has", length(levels(data$obs)),
               "levels. The prSummary() function isn't appropriate."))
  if (!all(levels(data[, "pred"]) == levels(data[, "obs"])))
    stop("levels of observed and predicted data do not match")

  # PRAUC() fails (via ROCR::prediction) when only one class is present,
  # so fall back to NA instead of stopping the whole train() run
  auc <- tryCatch(
    MLmetrics::PRAUC(y_pred = data[, lev[2]],
                     y_true = ifelse(data$obs == lev[2], 1, 0)),
    warning = function(w) { print(w); NA },
    error   = function(e) { print(dim(data)); NA }
  )

  c(AUC = auc,
    Precision = precision(data = data$pred, reference = data$obs, relevant = lev[2]),
    Recall    = recall(data = data$pred, reference = data$obs, relevant = lev[2]),
    F         = F_meas(data = data$pred, reference = data$obs, relevant = lev[2]))
}
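
With that in place, the custom summary function slots into the same trainControl call as before. The following sketch reuses cvIndex, cvIndexOut, tunegrid and seed as defined in the question.

control <- trainControl(index = cvIndex, indexOut = cvIndexOut,
                        method = "cv", summaryFunction = prSummaryCorrect,
                        classProbs = TRUE)
set.seed(seed)
rf_pr <- train(Class ~ ., data = dataset, method = "rf", metric = "AUC",
               tuneGrid = tunegrid, trControl = control)
rf_pr$resample  # per-fold metrics; AUC is NA where PRAUC could not be computed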
