How to obtain the trained best model from a CrossValidator
Problem description
I built a pipeline that includes a DecisionTreeClassifier (dt), like this:
val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))
Then I used this pipeline as the estimator in a CrossValidator to get a model with the best set of hyperparameters, like this:
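import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.CrossValidator
val c_v = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new MulticlassClassificationEvaluator()
    .setLabelCol("indexedLabel")
    .setPredictionCol("prediction"))
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)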
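The question never shows how paramGrid is built. A minimal sketch, assuming you want to tune the tree's maxDepth and impurity (the values below are illustrative choices, not taken from the question), could use ParamGridBuilder. ParamGridSketch is a hypothetical wrapper name used only so the snippet compiles on its own.

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.ParamGridBuilder

// Hypothetical wrapper object for this sketch.
object ParamGridSketch {
  // This must be the same dt instance that is placed in the Pipeline: each Param
  // is tied to its owner's uid, so a grid built on another instance would not take effect.
  val dt: DecisionTreeClassifier = new DecisionTreeClassifier()
    .setLabelCol("indexedLabel")
    .setFeaturesCol("indexedFeatures")

  // Illustrative values: 3 depths x 2 impurity measures = 6 candidate ParamMaps.
  val paramGrid: Array[ParamMap] = new ParamGridBuilder()
    .addGrid(dt.maxDepth, Array(3, 5, 10))
    .addGrid(dt.impurity, Array("gini", "entropy"))
    .build()
}
```

With setNumFolds(5), the CrossValidator fits 6 x 5 = 30 pipelines and then refits the best setting once on the whole training set to produce bestModel.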
Finally, I train a model on the training set with this CrossValidator:
val model = c_v.fit(train)
But the problem is that I want to view the best trained decision tree model with the .toDebugString method of DecisionTreeClassificationModel, and model is a CrossValidatorModel. Yes, you can use model.bestModel, but it is still of type Model, so you cannot call .toDebugString on it. I also assume that bestModel is still a pipeline that includes labelIndexer, featureIndexer, dt and labelConverter.
So does anyone know how I can obtain the DecisionTree model from the model fitted by the CrossValidator, so that I can view the actual tree with toDebugString? Or is there any workaround that lets me view the decision tree model?
Recommended answer
Well, in cases like this one the answer is always the same: be specific about the types.
First extract the pipeline model, since what you are training is a Pipeline:
import org.apache.spark.ml.PipelineModel
// bestModel is statically typed as Model[_], so match on its runtime type.
val bestModel: Option[PipelineModel] = model.bestModel match {
  case p: PipelineModel => Some(p)
  case _ => None
}
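Before digging into the stages, it can help to confirm which parameter combination produced bestModel. A CrossValidatorModel exposes getEstimatorParamMaps and avgMetrics (one fold-averaged metric per candidate, in the same order), so pairing them gives a ranking. CvReport, ranked and printRanking are hypothetical names for this sketch, and it assumes a larger-is-better metric, which holds for MulticlassClassificationEvaluator's default "f1".

```scala
import org.apache.spark.ml.param.ParamMap

// Hypothetical helper for this sketch; not part of Spark.
object CvReport {
  // Pairs each candidate ParamMap with its fold-averaged metric, best first.
  // Assumes larger is better (true for MulticlassClassificationEvaluator's default "f1").
  def ranked(paramMaps: Array[ParamMap], avgMetrics: Array[Double]): Seq[(ParamMap, Double)] =
    paramMaps.zip(avgMetrics).sortBy { case (_, metric) => -metric }.toSeq

  // Prints one line per candidate, e.g. the metric followed by "maxDepth=5, impurity=gini".
  def printRanking(paramMaps: Array[ParamMap], avgMetrics: Array[Double]): Unit =
    ranked(paramMaps, avgMetrics).foreach { case (pm, metric) =>
      val params = pm.toSeq.map(p => s"${p.param.name}=${p.value}").mkString(", ")
      println(f"$metric%.4f  $params")
    }
}
```

Usage with the question's model: CvReport.printRanking(model.getEstimatorParamMaps, model.avgMetrics).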
Then you'll need to extract the model from the underlying stage. In your case it's a DecisionTreeClassificationModel:
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
// Pick the first fitted stage that is a decision tree model.
val treeModel: Option[DecisionTreeClassificationModel] = bestModel.flatMap {
  _.stages.collect {
    case t: DecisionTreeClassificationModel => t
  }.headOption
}
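The same idea can be packaged as a small generic helper that works for any stage type, not only the decision tree. This is a sketch; StageLookup, firstStageOf and bestStageOf are hypothetical names. The ClassTag context bound lets the type pattern case t: T be checked at runtime despite type erasure.

```scala
import scala.reflect.ClassTag
import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.ml.tuning.CrossValidatorModel

// Hypothetical helper for this sketch; not part of Spark.
object StageLookup {
  // First stage whose runtime class is T. The ClassTag makes the
  // pattern `case t: T` a real runtime check instead of an unchecked one.
  def firstStageOf[T <: Transformer: ClassTag](stages: Array[Transformer]): Option[T] =
    stages.collectFirst { case t: T => t }

  // Best fitted stage of type T, assuming the CrossValidator's estimator was a Pipeline.
  def bestStageOf[T <: Transformer: ClassTag](cv: CrossValidatorModel): Option[T] =
    cv.bestModel match {
      case p: PipelineModel => firstStageOf[T](p.stages)
      case _                => None
    }
}
```

With the question's model this would read StageLookup.bestStageOf[DecisionTreeClassificationModel](model). Matching by type rather than by position (for example stages(2)) keeps working if stages are reordered, at the cost of returning only the first match.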
要打印树,例如:
treeModel.foreach(t => println(t.toDebugString))
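For a quick end-to-end check, here is a sketch that builds the question's pipeline on a tiny made-up DataFrame, tunes maxDepth with a CrossValidator, extracts the best tree as in the answer, and prints it. Assumptions: a local SparkSession, input columns named label and features, and stages set up the way Spark's decision tree classification example does; the data, grid values and the BestTreeDemo name are invented for illustration.

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession

object BestTreeDemo {
  // Fits the question's pipeline under a CrossValidator and returns the best fitted tree.
  def bestTree(spark: SparkSession): Option[DecisionTreeClassificationModel] = {
    import spark.implicits._
    // Made-up toy data: label is "yes" exactly when the first feature exceeds 20.
    val train = (0 until 60).map { i =>
      (if (i > 20) "yes" else "no", Vectors.dense(i.toDouble, (i % 3).toDouble))
    }.toDF("label", "features")

    val labelIndexer = new StringIndexer()
      .setInputCol("label").setOutputCol("indexedLabel").fit(train)
    // Features with at most 4 distinct values are treated as categorical.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(train)
    val dt = new DecisionTreeClassifier()
      .setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
    // `labels` is deprecated in Spark 3.x in favour of `labelsArray(0)`.
    val labelConverter = new IndexToString()
      .setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
    val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

    val paramGrid = new ParamGridBuilder().addGrid(dt.maxDepth, Array(1, 3)).build()
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel").setPredictionCol("prediction"))
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)
    val model = cv.fit(train)

    // Same extraction as in the answer: PipelineModel first, then the tree stage.
    model.bestModel match {
      case p: PipelineModel => p.stages.collectFirst { case t: DecisionTreeClassificationModel => t }
      case _                => None
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("best-tree").getOrCreate()
    try {
      bestTree(spark).foreach { t =>
        // The fitted model carries the params it was trained with, so getMaxDepth
        // shows which grid value the CrossValidator selected.
        println(s"maxDepth=${t.getMaxDepth} depth=${t.depth} nodes=${t.numNodes}")
        println(t.toDebugString)
      }
    } finally spark.stop()
  }
}
```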