How to use RandomForest in Spark Pipeline


Problem description


I want to tune my model with grid search and cross-validation in Spark. In Spark, the base model must be put in a pipeline. The official pipeline demo (http://spark.apache.org/docs/latest/ml-guide.html#example-model-selection-via-cross-validation) uses LogisticRegression as the base model, which can be instantiated with new as an ordinary object. However, the RandomForest model cannot be instantiated with new by client code, so it seems impossible to use RandomForest in the pipeline API. I don't want to reinvent the wheel, so can anybody give some advice? Thanks

Solution

However, the RandomForest model cannot be instantiated with new by client code, so it seems impossible to use RandomForest in the pipeline API.

Well, that is true, but you are simply trying to use the wrong class. Instead of mllib.tree.RandomForest you should use ml.classification.RandomForestClassifier. Here is an example based on the one from the MLlib docs.

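// Written for spark-shell, where sc and sqlContext are already defined.
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLUtils
import sqlContext.implicits._

// The label is kept as a string so that StringIndexer can index it below.
case class Record(category: String, features: Vector)

// Load the sample data as an RDD[LabeledPoint] and split it 70/30.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainData, testData) = (splits(0), splits(1))

// Convert both splits to DataFrames with columns "category" and "features".
val trainDF = trainData.map(lp => Record(lp.label.toString, lp.features)).toDF
val testDF = testData.map(lp => Record(lp.label.toString, lp.features)).toDF

// Map "category" to a numeric "label" column that carries nominal metadata.
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("label")

val rf = new RandomForestClassifier()
  .setNumTrees(3)
  .setFeatureSubsetStrategy("auto")
  .setImpurity("gini")
  .setMaxDepth(4)
  .setMaxBins(32)

// Chain the indexer and the classifier into a single pipeline.
val pipeline = new Pipeline()
  .setStages(Array(indexer, rf))

// Fit both stages on the training data, then predict on the test data.
val model = pipeline.fit(trainDF)

model.transform(testDF)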
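
To get from this pipeline to the grid search and cross-validation asked about in the question, the whole pipeline can be passed to CrossValidator as its estimator. What follows is only a minimal sketch, not a tested recipe: it continues the same shell session (reusing rf, pipeline, trainDF and testDF from the code above), it assumes a Spark version that ships MulticlassClassificationEvaluator (1.5 or later, as far as I know), and the candidate values in the grid are arbitrary placeholders rather than recommendations.

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Candidate hyperparameter values; these numbers are illustrative assumptions.
val paramGrid = new ParamGridBuilder()
  .addGrid(rf.numTrees, Array(3, 10))
  .addGrid(rf.maxDepth, Array(4, 8))
  .build()

// Scores each candidate by comparing the "prediction" column with "label".
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")

// The estimator is the whole pipeline, so StringIndexer is refit on every fold.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

// Fits every grid combination on every fold, then refits the best one on all of trainDF.
val cvModel = cv.fit(trainDF)

cvModel.transform(testDF).select("label", "prediction").show(5)
```

With two values for each of the two parameters, the grid contains four combinations, so with three folds the pipeline is fitted twelve times before the final refit.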

There is one thing I couldn't figure out. As far as I can tell, it should be possible to use the labels extracted from the LabeledPoints directly, without StringIndexer, but for some reason that doesn't work and pipeline.fit raises IllegalArgumentException:

RandomForestClassifier was given input with invalid label column label, without the number of classes specified.

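Hence the ugly trick with StringIndexer. The error message indicates that the classifier expects the number of classes to be specified on the label column. After applying the indexer, the label column carries the required attributes ({"vals":["1.0","0.0"],"type":"nominal","name":"label"}), although some classes in ml seem to work just fine without it.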
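
If the missing piece really is just that metadata, then attaching it by hand might be an alternative to StringIndexer. The sketch below is untested against this exact setup and rests on several assumptions: it continues the same shell session (reusing trainData, rf and the sqlContext implicits), it assumes the labels are already the doubles 0.0 and 1.0, and it assumes two classes because the attributes shown above list two values. The names rawDF, labelMeta, labeledDF and rfModel are made up for illustration.

```scala
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.functions.col

// LabeledPoint is a case class, so the RDD converts directly
// to a DataFrame with columns "label" and "features".
val rawDF = trainData.toDF

// Nominal attribute declaring two classes (an assumption about this dataset).
val labelMeta = NominalAttribute.defaultAttr
  .withName("label")
  .withNumValues(2)
  .toMetadata()

// Re-alias the label column so that it carries the metadata.
val labeledDF = rawDF.select(col("label").as("label", labelMeta), col("features"))

// Fit the classifier on its own, with no StringIndexer stage in front of it.
val rfModel = rf.fit(labeledDF)
```

Whether this is preferable to StringIndexer is a judgment call: the indexer derives the classes from the data, whereas here the number of classes is hard-coded and has to be kept correct by hand.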
