How to get the maxDepth from a Spark RandomForestRegressionModel


Question


In Spark (2.1.0) I've used a CrossValidator to train a RandomForestRegressor, using ParamGridBuilder for maxDepth and numTrees:

paramGrid = ParamGridBuilder() \
    .addGrid(rf.maxDepth, [2, 4, 6, 8, 10]) \
    .addGrid(rf.numTrees, [10, 20, 40, 50]) \
    .build()
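For scale: `ParamGridBuilder.build()` expands the grids above into every maxDepth/numTrees combination, so the CrossValidator fits 5 × 4 = 20 candidate models per fold. A plain-Python sketch of that expansion (the dict keys here are illustrative labels, not Spark `Param` objects):

```python
from itertools import product

# The same value lists as in the ParamGridBuilder above.
depths = [2, 4, 6, 8, 10]
num_trees = [10, 20, 40, 50]

# build() produces the cross product of all added grids,
# one parameter map per candidate model.
grid = [{"maxDepth": d, "numTrees": t} for d, t in product(depths, num_trees)]

print(len(grid))  # → 20
```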


After training, I can get the best number of trees:

regressor = cvModel.bestModel.stages[-1]

print(regressor.getNumTrees)


but I can't work out how to get the best maxDepth. I've read the documentation and I don't see what I'm missing.


I'd note that I can iterate through all the trees and find the depth of each one, e.g.

regressor.trees[0].depth


It seems like I'm missing something, though.

Answer


Unfortunately, before Spark 2.3 the PySpark RandomForestRegressionModel, unlike its Scala counterpart, doesn't store the upstream Estimator Params, but you should be able to retrieve the value directly from the underlying JVM object. With a simple monkey patch:

from pyspark.ml.regression import RandomForestRegressionModel

RandomForestRegressionModel.getMaxDepth = (
    lambda self: self._java_obj.getMaxDepth()
)

you can:

cvModel.bestModel.stages[-1].getMaxDepth()
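As an alternative that avoids reaching into the JVM: `CrossValidatorModel.avgMetrics` holds one cross-validated metric per entry of the parameter grid, in the same order as `cvModel.getEstimatorParamMaps()`, so the winning combination can be looked up by index. A minimal sketch of the index arithmetic, using made-up metric values in place of the real `cvModel.avgMetrics` and a smaller stand-in grid:

```python
# Hypothetical averaged metrics, one per grid entry (stand-in for
# cvModel.avgMetrics); higher is better for metrics such as R^2.
avg_metrics = [0.61, 0.64, 0.71, 0.69, 0.73, 0.70]

# Stand-in for cvModel.getEstimatorParamMaps(): one map per grid entry,
# in the same order as avg_metrics.
param_maps = [
    {"maxDepth": d, "numTrees": t}
    for d in [2, 4, 6]
    for t in [10, 20]
]

# Index of the best metric, then the matching parameter combination.
best_idx = max(range(len(avg_metrics)), key=avg_metrics.__getitem__)
best_params = param_maps[best_idx]
print(best_params)  # → {'maxDepth': 6, 'numTrees': 10}
```

Note that for loss-style metrics such as RMSE the best entry is the minimum, so `min(...)` would be used instead.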
