Spark 多类分类示例 [英] Spark Multiclass Classification Example

查看:61
本文介绍了Spark 多类分类示例的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

你们知道我在哪里可以找到 Spark 中多类分类的示例.我花了很多时间在书籍和网络上搜索,到目前为止,我只知道根据文档从最新版本开始是可能的.

Do you guys know where can I find examples of multiclass classification in Spark. I spent a lot of time searching in books and in the web, and so far I just know that it is possible since the latest version according the documentation.

推荐答案

ML

(在 Spark 2.0+ 中推荐)

我们将使用与下面 MLlib 中相同的数据.有两个基本选项.如果 Estimator 支持开箱即用的多类分类(例如随机森林),您可以直接使用它:

We'll use the same data as in the MLlib below. There are two basic options. If Estimator supports multilclass classification out-of-the-box (for example random forest) you can use it directly:

val trainRawDf = trainRaw.toDF

import org.apache.spark.ml.feature.{Tokenizer, CountVectorizer, StringIndexer}
import org.apache.spark.ml.Pipeline

import org.apache.spark.ml.classification.RandomForestClassifier

val transformers = Array(
  new StringIndexer().setInputCol("group").setOutputCol("label"),
  new Tokenizer().setInputCol("text").setOutputCol("tokens"),
  new CountVectorizer().setInputCol("tokens").setOutputCol("features")
)


val rf = new RandomForestClassifier() 
  .setLabelCol("label")
  .setFeaturesCol("features")

val model = new Pipeline().setStages(transformers :+ rf).fit(trainRawDf)

model.transform(trainRawDf)

如果模型仅支持二元分类(逻辑回归)并扩展o.a.s.ml.classification.Classifier,您可以使用 one-vs-rest 策略:

If model supports only binary classification (logistic regression) and extends o.a.s.ml.classification.Classifier you can use one-vs-rest strategy:

import org.apache.spark.ml.classification.OneVsRest
import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression() 
  .setLabelCol("label")
  .setFeaturesCol("features")

val ovr = new OneVsRest().setClassifier(lr)

val ovrModel = new Pipeline().setStages(transformers :+ ovr).fit(trainRawDf)

MLLib

根据官方文档此时(MLlib 1.6.0) 以下方法支持多类分类:

According to the official documentation at this moment (MLlib 1.6.0) following methods support multiclass classification:

  • 逻辑回归,
  • 决策树,
  • 随机森林,
  • 朴素贝叶斯

至少有一些例子使用了多类分类:

At least some of the examples use multiclass classification:

  • Naive Bayes example - 3 classes
  • Logistic regression - 10 classes for classifier although only 2 in the example data

忽略方法特定参数的通用框架与 MLlib 中的所有其他方法几乎相同.您必须预处理您的输入以创建具有代表 labelfeatures 的列的任一数据框:

General framework, ignoring method specific arguments, is pretty much the same as for all the other methods in MLlib. You have to pre-processes your input to create either data frame with columns representing label and features:

root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)

RDD[LabeledPoint].

Spark 提供了广泛的有用工具,旨在促进这一过程,包括 特征提取器特征变换器管道.

Spark provides broad range of useful tools designed to facilitate this process including Feature Extractors and Feature Transformers and pipelines.

您会在下面找到一个使用随机森林的相当幼稚的例子.

You'll find a rather naive example of using Random Forest below.

首先让我们导入所需的包并创建虚拟数据:

First lets import required packages and create dummy data:

import sqlContext.implicits._
import org.apache.spark.ml.feature.{HashingTF, Tokenizer} 
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.Row
import org.apache.spark.rdd.RDD

case class LabeledRecord(group: String, text: String)

val trainRaw = sc.parallelize(
    LabeledRecord("foo", "foo v a y b  foo") ::
    LabeledRecord("bar", "x bar y bar v") ::
    LabeledRecord("bar", "x a y bar z") ::
    LabeledRecord("foobar", "foo v b bar z") ::
    LabeledRecord("foo", "foo x") ::
    LabeledRecord("foobar", "z y x foo a b bar v") ::
    Nil
)

现在让我们定义所需的转换器和流程Dataset:

Now let's define required transformers and process train Dataset:

// Tokenizer to process text fields
val tokenizer = new Tokenizer()
    .setInputCol("text")
    .setOutputCol("words")

// HashingTF to convert tokens to the feature vector
val hashingTF = new HashingTF()
    .setInputCol("words")
    .setOutputCol("features")
    .setNumFeatures(10)

// Indexer to convert String labels to Double
val indexer = new StringIndexer()
    .setInputCol("group")
    .setOutputCol("label")
    .fit(trainRaw.toDF)


def transfom(rdd: RDD[LabeledRecord]) = {
    val tokenized = tokenizer.transform(rdd.toDF)
    val hashed = hashingTF.transform(tokenized)
    val indexed = indexer.transform(hashed)
    indexed
        .select($"label", $"features")
        .map{case Row(label: Double, features: Vector) =>
            LabeledPoint(label, features)}
}

val train: RDD[LabeledPoint] = transfom(trainRaw)

请注意,indexer 是拟合"在火车数据上的.它只是意味着用作标签的分类值被转换为 doubles.要在新数据上使用分类器,您必须首先使用此 indexer 对其进行转换.

Please note that indexer is "fitted" on the train data. It simply means that categorical values used as the labels are converted to doubles. To use classifier on a new data you have to transform it first using this indexer.

接下来我们可以训练 RF 模型:

Next we can train RF model:

val numClasses = 3
val categoricalFeaturesInfo = Map[Int, Int]()
val numTrees = 10
val featureSubsetStrategy = "auto"
val impurity = "gini"
val maxDepth = 4
val maxBins = 16

val model = RandomForest.trainClassifier(
    train, numClasses, categoricalFeaturesInfo, 
    numTrees, featureSubsetStrategy, impurity,
    maxDepth, maxBins
)

最后测试一下:

val testRaw = sc.parallelize(
    LabeledRecord("foo", "foo  foo z z z") ::
    LabeledRecord("bar", "z bar y y v") ::
    LabeledRecord("bar", "a a  bar a z") ::
    LabeledRecord("foobar", "foo v b bar z") ::
    LabeledRecord("foobar", "a foo a bar") ::
    Nil
)

val test: RDD[LabeledPoint] = transfom(testRaw)

val predsAndLabs = test.map(lp => (model.predict(lp.features), lp.label))
val metrics = new MulticlassMetrics(predsAndLabs)

metrics.precision
metrics.recall

这篇关于Spark 多类分类示例的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆