How to obtain the trained best model from a CrossValidator

Question

I built a pipeline that includes a DecisionTreeClassifier (dt) like this:

val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

Then I used this pipeline as the estimator in a CrossValidator, in order to get a model with the best set of hyperparameters:

val c_v = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(
    new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction"))
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)

Finally, I could train a model on the training set with this CrossValidator:

val model = c_v.fit(train)

The problem is that I want to inspect the best trained decision tree through the toDebugString member of DecisionTreeClassificationModel, but model is a CrossValidatorModel. I can use model.bestModel, but it is still of type Model, so I cannot call toDebugString on it. I also assume that bestModel is still a pipeline containing labelIndexer, featureIndexer, dt, and labelConverter.

So does anyone know how to obtain the decision tree model from the model fitted by the CrossValidator, so that I can view the actual tree with toDebugString? Or is there a workaround for viewing the decision tree model?

Recommended answer

In cases like this one the answer is always the same: be specific about the types.

First extract the pipeline model, since what you are training is a Pipeline:

import org.apache.spark.ml.PipelineModel

val bestModel: Option[PipelineModel] = model.bestModel match {
  case p: PipelineModel => Some(p)
  case _ => None
}

Then you'll need to extract the model from the underlying stage. In your case it's a decision tree classification model:

import org.apache.spark.ml.classification.DecisionTreeClassificationModel

val treeModel: Option[DecisionTreeClassificationModel] = bestModel
  .flatMap {
    _.stages.collect {
      case t: DecisionTreeClassificationModel => t
    }.headOption
  }
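
The two extraction steps above can also be folded into one small function. The following is only a sketch: the helper name bestTree is hypothetical, and it assumes the estimator passed to the CrossValidator was a Pipeline with at least one decision tree classifier stage. If either assumption does not hold, it returns None rather than throwing.

```scala
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.tuning.CrossValidatorModel

// Hypothetical helper (not part of Spark): returns the first fitted
// decision tree stage of the best model, or None if the best model is
// not a PipelineModel or contains no such stage.
def bestTree(cv: CrossValidatorModel): Option[DecisionTreeClassificationModel] =
  cv.bestModel match {
    case p: PipelineModel =>
      // collectFirst stops at the first stage that matches the type pattern
      p.stages.collectFirst { case t: DecisionTreeClassificationModel => t }
    case _ => None
  }
```

With the model value from the question, bestTree(model) yields the same Option as treeModel above.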

To print the tree, for example:

treeModel.foreach(t => println(t.toDebugString))
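
The debug string shows the tree itself, but not which entry of paramGrid produced it. As a rough sketch, the winning parameter map can be recovered by pairing getEstimatorParamMaps with avgMetrics on the fitted CrossValidatorModel. The helper name bestParams is hypothetical, and the selection rule is an assumption: it takes the winner to be the entry with the best average metric, in the direction given by the evaluator's isLargerBetter.

```scala
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.CrossValidatorModel

// Hypothetical helper (not part of Spark): pairs every candidate ParamMap
// with its cross-validated average metric and returns the best pair.
def bestParams(cv: CrossValidatorModel): (ParamMap, Double) = {
  val scored = cv.getEstimatorParamMaps.zip(cv.avgMetrics)
  // Some metrics are better when larger (e.g. f1), others when smaller (e.g. rmse)
  if (cv.getEvaluator.isLargerBetter) scored.maxBy(_._2)
  else scored.minBy(_._2)
}
```

For the model value from the question, this could be used as:

```scala
val (params, metric) = bestParams(model)
println(s"best average metric: $metric")
println(params)
```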
