How to use RandomForest in Spark Pipeline

Question

I want to tune my model with grid search and cross-validation in Spark. In Spark, the base model has to be put in a pipeline, and the official pipeline demo uses LogisticRegression as the base model, which can be created with new. However, the RandomForest model cannot be instantiated with new by client code, so it seems impossible to use RandomForest with the pipeline API. I don't want to reinvent the wheel, so can anybody give some advice? Thanks

Answer

However, the RandomForest model cannot be instantiated with new by client code, so it seems impossible to use RandomForest with the pipeline API.

Well, that is true, but you are simply trying to use the wrong class. Instead of mllib.tree.RandomForest you should use ml.classification.RandomForestClassifier. Here is an example based on the one from the MLlib docs.

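import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLUtils
// Written for the Spark 1.x spark-shell, where sc and sqlContext are predefined
import sqlContext.implicits._

// Keep the label as a String so StringIndexer can turn it into a nominal "label" column
case class Record(category: String, features: Vector)

// RDD[LabeledPoint] loaded from the sample LIBSVM file shipped with Spark
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainData, testData) = (splits(0), splits(1))

// Convert the LabeledPoints to DataFrames with "category" and "features" columns
val trainDF = trainData.map(lp => Record(lp.label.toString, lp.features)).toDF
val testDF = testData.map(lp => Record(lp.label.toString, lp.features)).toDF

// Index "category" into "label"; this also attaches the nominal metadata the classifier needs
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("label")

// ml.classification.RandomForestClassifier is an Estimator, so it can be a pipeline stage
val rf = new RandomForestClassifier()
  .setNumTrees(3)
  .setFeatureSubsetStrategy("auto")
  .setImpurity("gini")
  .setMaxDepth(4)
  .setMaxBins(32)

val pipeline = new Pipeline()
  .setStages(Array(indexer, rf))

// Fits the indexer first, then the random forest on the indexed data
val model = pipeline.fit(trainDF)

// Adds the indexed label and the prediction columns to the test DataFrame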
model.transform(testDF)
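Since the original goal was grid search with cross-validation, here is a minimal sketch of wrapping the same kind of pipeline in CrossValidator with a ParamGridBuilder grid. It assumes Spark 2.x or later (for SparkSession, the "libsvm" data source and the "accuracy" metric of MulticlassClassificationEvaluator), a local master, and the sample data path used above; the grid values, the seed and the column name indexedLabel are arbitrary choices for illustration.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession

// Assumption: Spark 2.x+ running locally, sample data shipped with the Spark distribution.
val spark = SparkSession.builder().master("local[*]").appName("rf-cv-sketch").getOrCreate()

// The "libsvm" source yields a Double "label" column and a Vector "features" column.
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
val Array(trainDF, testDF) = data.randomSplit(Array(0.7, 0.3), seed = 42L)

// Index the label so the classifier receives nominal label metadata.
val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")

val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")

val pipeline = new Pipeline().setStages(Array(indexer, rf))

// 2 x 2 grid: four candidate parameter settings.
val paramGrid = new ParamGridBuilder()
  .addGrid(rf.numTrees, Array(3, 10))
  .addGrid(rf.maxDepth, Array(2, 4))
  .build()

val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")

// Each parameter setting is scored by 3-fold cross-validation on the training data.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEstimatorParamMaps(paramGrid)
  .setEvaluator(evaluator)
  .setNumFolds(3)

val cvModel = cv.fit(trainDF)

// cvModel applies the best pipeline found by the search to the held-out data.
val predictions = cvModel.transform(testDF)
val testAccuracy = evaluator.evaluate(predictions)
println(s"Test accuracy: $testAccuracy")
```

Because the StringIndexer sits inside the pipeline, it is refit on every training fold, so the evaluator always compares predictions against the label indexing of that same fold.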

There is one thing I couldn't figure out here. As far as I can tell, it should be possible to use labels extracted from LabeledPoints directly, but for some reason it doesn't work and pipeline.fit raises an IllegalArgumentException:

RandomForestClassifier was given input with invalid label column label, without the number of classes specified.
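The message suggests that the classifier looks for the number of classes in the label column's ML attribute metadata, which a plain Double label taken from LabeledPoints does not carry. A minimal sketch of attaching that metadata by hand with NominalAttribute, without a StringIndexer stage, follows. It assumes Spark 2.x or later (SparkSession and the "libsvm" reader) and a binary label; newer Spark releases may also infer the class count from the data when the metadata is missing, in which case this step may simply be redundant.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Assumption: Spark 2.x+ in local mode and the sample data path used above.
val spark = SparkSession.builder().master("local[*]").appName("label-metadata-sketch").getOrCreate()

// Double "label" column without any ML attribute metadata.
val raw = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

// Nominal attribute declaring two classes (0.0 and 1.0) for the label.
val labelMeta = NominalAttribute.defaultAttr
  .withName("label")
  .withNumValues(2)
  .toMetadata()

// Re-select the label under the same name with the metadata attached.
val labeled = raw.select(col("label").as("label", labelMeta), col("features"))

// No StringIndexer stage: the class count comes from the label metadata.
val rf = new RandomForestClassifier()
  .setNumTrees(3)
  .setMaxDepth(4)

val model = new Pipeline().setStages(Array(rf)).fit(labeled)

val numClasses = model.stages(0).asInstanceOf[RandomForestClassificationModel].numClasses
println(s"numClasses = $numClasses")
```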

Hence the ugly trick with StringIndexer. After applying it we get the required attributes ({"vals":["1.0","0.0"],"type":"nominal","name":"label"}), but some classes in ml seem to work just fine without them.
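The order of the "vals" array reflects StringIndexer's default frequency ordering: the most frequent string gets index 0.0, so in that metadata the original label "1.0" becomes 0.0 and "0.0" becomes 1.0. The pure-Scala sketch below uses a hypothetical helper, frequencyIndex, that mimics this ordering to make the flip visible; it is an illustration, not Spark's implementation, and breaking ties alphabetically is a choice made for this sketch.

```scala
// Hypothetical helper mimicking StringIndexer's default frequency ordering
// (illustration only, not Spark's code).
def frequencyIndex(labels: Seq[String]): Map[String, Double] =
  labels
    .groupBy(identity)                                  // label -> all of its occurrences
    .toSeq
    .sortBy { case (label, occ) => (-occ.size, label) } // most frequent first, ties alphabetical
    .zipWithIndex
    .map { case ((label, _), idx) => label -> idx.toDouble }
    .toMap

// Made-up labels in which "1.0" is more frequent than "0.0".
val labels = Seq("1.0", "0.0", "1.0", "1.0", "0.0")
val mapping = frequencyIndex(labels)

// Values ordered by their index, comparable to the "vals" array in the metadata.
val vals = mapping.toSeq.sortBy(_._2).map(_._1)
val indexed = labels.map(mapping)

println(s"vals = $vals, indexed = $indexed")
```

In practice this means the prediction column is expressed in the indexed label space; if the original labels are needed, ml.feature.IndexToString can map the indices back.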