如何访问个别树木由RandomForestClassifier(spark.ml版本)创建的模型? [英] How to access individual trees in a model created by RandomForestClassifier (spark.ml-version)?

查看:1227
本文介绍了如何访问个别树木由RandomForestClassifier(spark.ml版本)创建的模型?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

如何访问个别树木火花ML的<生成一个模型href=\"http://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier\"相对=nofollow> RandomForestClassifier ?我使用RandomForestClassifier的斯卡拉版本。

How to access individual trees in a model generated by Spark ML's RandomForestClassifier? I am using the Scala version of RandomForestClassifier.

推荐答案

其实它有树木属性:

import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.{
  RandomForestClassificationModel, RandomForestClassifier, 
  DecisionTreeClassificationModel
}

val meta = NominalAttribute
  .defaultAttr
  .withName("label")
  .withValues("0.0", "1.0")
  .toMetadata

val data = sqlContext.read.format("libsvm")
  .load("data/mllib/sample_libsvm_data.txt")
  .withColumn("label", $"label".as("label", meta))

val rf: RandomForestClassifier = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")

val trees: Array[DecisionTreeClassificationModel] = rf.fit(data).trees.collect {
  case t: DecisionTreeClassificationModel => t
}

正如你所看到的唯一的问题是获得正确的类型,所以我们实际上可以使用这些:

As you can see the only problem is to get types right so we can actually use these:

trees.head.transform(data).show(3)
// +-----+--------------------+-------------+-----------+----------+
// |label|            features|rawPrediction|probability|prediction|
// +-----+--------------------+-------------+-----------+----------+
// |  0.0|(692,[127,128,129...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
// |  1.0|(692,[158,159,160...|   [0.0,59.0]|  [0.0,1.0]|       1.0|
// |  1.0|(692,[124,125,126...|   [0.0,59.0]|  [0.0,1.0]|       1.0|
// +-----+--------------------+-------------+-----------+----------+
// only showing top 3 rows

注意

如果您使用的管道工作,你可以提取单个的树木,以及:

If you work with pipelines you can extract individual trees as well:

import org.apache.spark.ml.Pipeline

val model = new Pipeline().setStages(Array(rf)).fit(data)

// There is only one stage and know its type 
// but lets be thorough
val rfModelOption = model.stages.headOption match {
  case Some(m: RandomForestClassificationModel) => Some(m)
  case _ => None
}

val trees = rfModelOption.map {
  _.trees //  ... as before
}.getOrElse(Array())

这篇关于如何访问个别树木由RandomForestClassifier(spark.ml版本)创建的模型?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆