SPARK,ML,Tuning,CrossValidator:访问指标 [英] SPARK, ML, Tuning, CrossValidator: access the metrics

查看:112
本文介绍了SPARK,ML,Tuning,CrossValidator:访问指标的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

为了构建NaiveBayes多类分类器,我正在使用CrossValidator在管道中选择最佳参数:

In order to build a NaiveBayes multiclass classifier, I am using a CrossValidator to select the best parameters in my pipeline:

val cv = new CrossValidator()
        .setEstimator(pipeline)
        .setEstimatorParamMaps(paramGrid)
        .setEvaluator(new MulticlassClassificationEvaluator)
        .setNumFolds(10)

val cvModel = cv.fit(trainingSet)

管道按以下顺序包含常用的转换器和估计器:Tokenizer,StopWordsRemover,HashingTF,IDF,最后是NaiveBayes.

The pipeline contains usual transformers and estimators in the following order: Tokenizer, StopWordsRemover, HashingTF, IDF and finally the NaiveBayes.

是否可以访问为最佳模型计算的指标?

Is it possible to access the metrics calculated for best model?

理想情况下,我想访问所有模型的指标,以了解更改参数如何改变分类的质量. 但是目前,最好的模型已经足够了.

Ideally, I would like to access the metrics of all models to see how changing the parameters is changing the quality of the classification. But for the moment, the best model is good enough.

仅供参考,我正在使用Spark 1.6.0

FYI, I am using Spark 1.6.0

推荐答案

这是我的方法:

val pipeline = new Pipeline()
  .setStages(Array(tokenizer, stopWordsFilter, tf, idf, word2Vec, featureVectorAssembler, categoryIndexerModel, classifier, categoryReverseIndexer))

...

val paramGrid = new ParamGridBuilder()
  .addGrid(tf.numFeatures, Array(10, 100))
  .addGrid(idf.minDocFreq, Array(1, 10))
  .addGrid(word2Vec.vectorSize, Array(200, 300))
  .addGrid(classifier.maxDepth, Array(3, 5))
  .build()

paramGrid.size // 16 entries

...

// Print the average metrics per ParamGrid entry
val avgMetricsParamGrid = crossValidatorModel.avgMetrics

// Combine with paramGrid to see how they affect the overall metrics
val combined = paramGrid.zip(avgMetricsParamGrid)

...

val bestModel = crossValidatorModel.bestModel.asInstanceOf[PipelineModel]

// Explain params for each stage
val bestHashingTFNumFeatures = bestModel.stages(2).asInstanceOf[HashingTF].explainParams
val bestIDFMinDocFrequency = bestModel.stages(3).asInstanceOf[IDFModel].explainParams
val bestWord2VecVectorSize = bestModel.stages(4).asInstanceOf[Word2VecModel].explainParams
val bestDecisionTreeDepth = bestModel.stages(7).asInstanceOf[DecisionTreeClassificationModel].explainParams

这篇关于SPARK,ML,Tuning,CrossValidator:访问指标的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆