pyspark: the best model's parameters are blank {} after a grid search


Problem description

Could someone help me extract the best-performing model's parameters from my grid search? For some reason I get a blank dictionary.

from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator


# df, lr and pipeline are defined earlier (not shown in the question)
train, test = df.randomSplit([0.66, 0.34], seed=12345)

paramGrid = (ParamGridBuilder()
             .addGrid(lr.regParam, [0.01,0.1])
             .addGrid(lr.elasticNetParam, [1.0,])
             .addGrid(lr.maxIter, [3,])
             .build())

evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction",labelCol="buy")
evaluator.setMetricName('areaUnderROC')

cv = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=evaluator,
                          numFolds=2)  
cvModel = cv.fit(train)

> print(cvModel.bestModel)  # it looks like I have a valid bestModel
PipelineModel_406e9483e92ebda90524

> cvModel.bestModel.extractParamMap()  # fails: returns an empty dict
{}

> cvModel.bestModel.getRegParam()  # also fails
AttributeError                            Traceback (most recent call last)
<ipython-input-9-747196173391> in <module>()
----> 1 cvModel.bestModel.getRegParam()

AttributeError: 'PipelineModel' object has no attribute 'getRegParam'

Recommended answer

There are two different problems here:

  • Parameters are set on the individual Estimators or Transformers, not on the PipelineModel. All fitted models can be accessed through its stages property.
  • Before Spark 2.3, Python models don't contain Params at all (SPARK-10931).

So unless you use the development branch, you have to find the model of interest among the stages, access its _java_obj, and read the parameters of interest from it. For example:

from pyspark.ml.classification import LogisticRegressionModel

[x._java_obj.getRegParam() 
for x in cvModel.bestModel.stages if isinstance(x, LogisticRegressionModel)]
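
If the goal is only to see which grid point won, another option that does not depend on the model's Params is to pair each cross-validated score with the param map that produced it. The sketch below is a plain-Python helper, not part of the original answer. The name best_param_map is hypothetical, the sample numbers are made up for illustration, and it assumes the scores are listed in the same order as the param maps.

```python
def best_param_map(avg_metrics, param_maps, larger_is_better=True):
    """Return (metric, {param_name: value}) for the best-scoring grid point.

    Assumes avg_metrics[i] is the averaged score for param_maps[i].
    """
    if len(avg_metrics) != len(param_maps):
        raise ValueError("avg_metrics and param_maps must have the same length")
    if not avg_metrics:
        raise ValueError("the grid is empty")
    pick = max if larger_is_better else min
    best_index = pick(range(len(avg_metrics)), key=lambda i: avg_metrics[i])
    # In a real grid the keys are pyspark Param objects, which carry a .name;
    # plain string keys (as in the stand-in data below) are kept unchanged.
    readable = {getattr(p, "name", p): v
                for p, v in param_maps[best_index].items()}
    return avg_metrics[best_index], readable


# Stand-in values for illustration only (not results from the question).
example_metrics = [0.71, 0.68]
example_grid = [
    {"regParam": 0.01, "elasticNetParam": 1.0, "maxIter": 3},
    {"regParam": 0.1, "elasticNetParam": 1.0, "maxIter": 3},
]
best_metric, best_params = best_param_map(example_metrics, example_grid)
print(best_metric, best_params)
```

With the objects from the question, and assuming your Spark version exposes avgMetrics on the fitted CrossValidatorModel, the call would look like best_param_map(cvModel.avgMetrics, cv.getEstimatorParamMaps(), evaluator.isLargerBetter()).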
