Get decision tree rule/path pattern for every row of predicted dataset for rpart/ctree package in R

Question

I have built a decision tree model in R using rpart and ctree. I have also used the fitted model to predict a new dataset and obtained the predicted probabilities and classes.

However, I would like to extract the rule/path that each observation in the predicted dataset has followed, as a single string. With this stored in tabular format, I can explain each prediction and its reason in an automated manner without opening R.

That is, I want output like the following:

ObsID   Probability   PredictedClass   PathFollowed 
    1          0.68             Safe   CarAge < 10 & Country = Germany & Type = Compact & Price < 12822.5
    2          0.76             Safe   CarAge < 10 & Country = Korea & Type = Compact & Price > 12822.5
    3          0.88           Unsafe   CarAge > 10 & Type = Van & Country = USA & Price > 15988

The kind of code I am looking for is:

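library(rpart)
fit <- rpart(Reliability ~ ., data = car.test.frame)

## The call below may need to be expanded into multiple lines;
## "GETPATTERNS" is a wished-for option, not an existing predict() type.
predResults <- predict(fit, newdata = newcar, type = "GETPATTERNS")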


Recommended answer

The partykit package has a function .list.rules.party() which is currently unexported but can be leveraged to do what you want. The main reason we have not exported it yet is that its type of output may change in future versions.
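
As a quick look at what this helper returns on its own, here is a minimal sketch. It assumes the current behavior of the unexported function (a character vector of rules, named by terminal node id), which may change between partykit versions as noted above.

library("rpart")
library("partykit")

## fit a tree and coerce it to a "party" object
rp_party <- as.party(rpart(Species ~ ., data = iris))

## unexported helper, hence the triple colon; one rule string per terminal node
rls <- partykit:::.list.rules.party(rp_party)

## the names of the vector are the terminal node ids, as character strings
names(rls)
rls

Because the rules are indexed by node id, they can be matched against the node predicted for each observation.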

To obtain the predictions you describe above, you can do:

pathpred <- function(object, ...)
{
  ## coerce to "party" object if necessary
  if(!inherits(object, "party")) object <- as.party(object)

  ## get standard predictions (response/prob) and collect in data frame
  rval <- data.frame(response = predict(object, type = "response", ...))
  rval$prob <- predict(object, type = "prob", ...)

  ## get rules for each node
  rls <- partykit:::.list.rules.party(object)

  ## get predicted node and select corresponding rule
  rval$rule <- rls[as.character(predict(object, type = "node", ...))]

  return(rval)
}

Using the iris data and rpart() for illustration:

library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.90740741     0.09259259
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                           rule
## 1                          Petal.Length < 2.45
## 51   Petal.Length >= 2.45 & Petal.Width < 1.75
## 101 Petal.Length >= 2.45 & Petal.Width >= 1.75

(Only the first observation of each species is shown here for brevity. These correspond to rows 1, 51, and 101.)

And using ctree():

ct <- ctree(Species ~ ., data = iris)
ct_pred <- pathpred(ct)
ct_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.97826087     0.02173913
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                                              rule
## 1                                             Petal.Length <= 1.9
## 51  Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8
## 101                        Petal.Length > 1.9 & Petal.Width > 1.7
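
To get from here to the flat table sketched in the question (ObsID, Probability, PredictedClass, PathFollowed) for a new dataset, something along the following lines may work. This is only a sketch: the train/new split of iris, the object names, and the output file name are made up for illustration, and it relies on the same unexported helper.

library("rpart")
library("partykit")

## hypothetical split: 100 rows to fit the tree, the remaining 50 as "new" data
set.seed(1)
idx <- sample(nrow(iris), 100)
train <- iris[idx, ]
newdat <- iris[-idx, ]

tree <- as.party(rpart(Species ~ ., data = train))

## class, class probabilities (matrix), and terminal node id for each new row
pred_class <- predict(tree, newdata = newdat, type = "response")
pred_prob <- predict(tree, newdata = newdat, type = "prob")
pred_node <- predict(tree, newdata = newdat, type = "node")

## rules per terminal node, named by node id
rules <- partykit:::.list.rules.party(tree)

## pick, for each row, the probability of the class that was predicted
col_idx <- match(as.character(pred_class), colnames(pred_prob))
out <- data.frame(
  ObsID = rownames(newdat),
  Probability = pred_prob[cbind(seq_len(nrow(newdat)), col_idx)],
  PredictedClass = pred_class,
  PathFollowed = unname(rules[as.character(pred_node)]),
  stringsAsFactors = FALSE
)

## store as a plain CSV file so the table can be read without R
out_file <- file.path(tempdir(), "predictions_with_paths.csv")
write.csv(out, out_file, row.names = FALSE)

Keeping a single probability column avoids the matrix-valued prob column that pathpred() returns, which is convenient when the table is exported.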
