Getting the leaf probabilities of a tree model in Spark
Problem description
I'm trying to refactor a trained Spark tree-based model (RandomForest or GBT classifier) so that it can be exported to environments without Spark. The toDebugString method is a good starting point. However, in the case of RandomForestClassifier, the string only shows the predicted class for each tree, without the relative probabilities. So, if you average the predictions over all the trees, you get a wrong result.
An example. We have a DecisionTree represented this way:
DecisionTreeClassificationModel (uid=dtc_884dc2111789) of depth 2 with 5 nodes
  If (feature 21 in {1.0})
   Predict: 0.0
  Else (feature 21 not in {1.0})
   If (feature 10 in {0.0})
    Predict: 0.0
   Else (feature 10 not in {0.0})
    Predict: 1.0
As we can see, following the nodes, the prediction always appears to be either 0 or 1. However, if I apply this single tree to a feature vector, I get probabilities like [0.1007, 0.8993], and they make perfect sense: in the training set, the proportion of negatives/positives that end up in the same leaf as the example vector matches the output probabilities.
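Consider a leaf that remembers how many training instances of each class reached it. Its class probabilities are just those counts normalized to sum to one. Below is a minimal sketch with hypothetical counts. The function name `leaf_probabilities` is illustrative, not a Spark API.

```python
def leaf_probabilities(class_counts):
    """Normalize a leaf's per-class training counts into class probabilities."""
    total = sum(class_counts)
    if total == 0:
        raise ValueError("leaf has no training instances")
    return [c / total for c in class_counts]


# Hypothetical leaf: 1 negative and 9 positive training instances ended up here.
print(leaf_probabilities([1.0, 9.0]))  # [0.1, 0.9]
```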
My questions: where are these probabilities stored? Is there a way to extract them? If so, how? A pyspark solution would be preferred.
Recommended answer
I'm trying to refactor a trained Spark tree-based model (RandomForest or GBT classifiers) in such a way it can be exported in environments without Spark.
Given the growing number of tools designed for real-time serving of Spark (and other) models, that's probably reinventing the wheel.
However, if you want to access model internals from plain Python, it is best to load its serialized form.
Assuming you have:
from pyspark.ml.classification import RandomForestClassificationModel
rf_model: RandomForestClassificationModel
path: str # Absolute path
then save the model:
rf_model.write().save(path)
You can load it back using a Parquet reader that supports mixed struct and list types. The model writer writes both the node data:
node_data = spark.read.parquet("{}/data".format(path))
node_data.printSchema()
root
 |-- treeID: integer (nullable = true)
 |-- nodeData: struct (nullable = true)
 |    |-- id: integer (nullable = true)
 |    |-- prediction: double (nullable = true)
 |    |-- impurity: double (nullable = true)
 |    |-- impurityStats: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- rawCount: long (nullable = true)
 |    |-- gain: double (nullable = true)
 |    |-- leftChild: integer (nullable = true)
 |    |-- rightChild: integer (nullable = true)
 |    |-- split: struct (nullable = true)
 |    |    |-- featureIndex: integer (nullable = true)
 |    |    |-- leftCategoriesOrThreshold: array (nullable = true)
 |    |    |    |-- element: double (containsNull = true)
 |    |    |-- numCategories: integer (nullable = true)
and the tree metadata:
tree_meta = spark.read.parquet("{}/treesMetadata".format(path))
tree_meta.printSchema()
root
 |-- treeID: integer (nullable = true)
 |-- metadata: string (nullable = true)
 |-- weights: double (nullable = true)
where the former provides all the information you need, as the prediction process is basically an aggregation of impurityStats*.
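Every node record carries the ids of its children, so a single tree can also be evaluated directly from the flat records without rebuilding nested objects. Here is a minimal sketch. It assumes the rows of one tree have already been turned into plain dicts keyed by `id`, for example by some Parquet reader. The helper names `find_leaf` and `tree_probability` and all `impurityStats` values are hypothetical. The tree shape mirrors the toDebugString example from the question.

```python
from typing import Any, Dict, List, Sequence

# Flat node records shaped like the `nodeData` struct above, keeping only the
# fields used for prediction. The impurityStats counts are made up.
example_tree: Dict[int, Dict[str, Any]] = {
    0: {"leftChild": 1, "rightChild": 2, "impurityStats": [15.0, 15.0],
        "split": {"featureIndex": 21, "leftCategoriesOrThreshold": [1.0], "numCategories": 2}},
    1: {"leftChild": -1, "rightChild": -1, "impurityStats": [8.0, 2.0]},  # leaf
    2: {"leftChild": 3, "rightChild": 4, "impurityStats": [7.0, 13.0],
        "split": {"featureIndex": 10, "leftCategoriesOrThreshold": [0.0], "numCategories": 2}},
    3: {"leftChild": -1, "rightChild": -1, "impurityStats": [6.0, 4.0]},  # leaf
    4: {"leftChild": -1, "rightChild": -1, "impurityStats": [1.0, 9.0]},  # leaf
}


def find_leaf(nodes_by_id: Dict[int, Dict[str, Any]], vector: Sequence[float]) -> Dict[str, Any]:
    """Walk from the root (id 0) down to a leaf using the flat records."""
    node = nodes_by_id[0]
    while node["leftChild"] != -1:  # leaves have leftChild == rightChild == -1
        split = node["split"]
        value = vector[split["featureIndex"]]
        left_value = split["leftCategoriesOrThreshold"]
        if split["numCategories"] == -1:
            go_left = value <= left_value[0]  # continuous: one-element threshold list
        else:
            go_left = value in left_value  # categorical: listed categories go left
        node = nodes_by_id[node["leftChild"] if go_left else node["rightChild"]]
    return node


def tree_probability(nodes_by_id: Dict[int, Dict[str, Any]], vector: Sequence[float]) -> List[float]:
    """Normalize the reached leaf's impurityStats into class probabilities."""
    stats = find_leaf(nodes_by_id, vector)["impurityStats"]
    total = sum(stats)
    return [s / total for s in stats]


features = [0.0] * 22
features[10] = 1.0  # feature 21 is 0.0 and feature 10 is 1.0, so the rightmost leaf
print(tree_probability(example_tree, features))  # [0.1, 0.9]
```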
You could also access this data directly through the underlying Java objects:
from collections import namedtuple
import numpy as np

LeafNode = namedtuple("LeafNode", ("prediction", "impurity"))
InternalNode = namedtuple(
    "InternalNode", ("left", "right", "prediction", "impurity", "split"))
CategoricalSplit = namedtuple("CategoricalSplit", ("feature_index", "categories"))
ContinuousSplit = namedtuple("ContinuousSplit", ("feature_index", "threshold"))

def jtree_to_python(jtree):
    def jsplit_to_python(jsplit):
        if jsplit.getClass().toString().endswith(".ContinuousSplit"):
            return ContinuousSplit(jsplit.featureIndex(), jsplit.threshold())
        else:
            # The old-API split lists the categories that go to the left child
            jcat = jsplit.toOld().categories()
            return CategoricalSplit(
                jsplit.featureIndex(),
                [jcat.apply(i) for i in range(jcat.length())])

    def jnode_to_python(jnode):
        prediction = jnode.prediction()
        # Per-class (weighted) label counts collected at this node during training
        stats = np.array(list(jnode.impurityStats().stats()))

        if jnode.numDescendants() != 0:  # InternalNode
            left = jnode_to_python(jnode.leftChild())
            right = jnode_to_python(jnode.rightChild())
            split = jsplit_to_python(jnode.split())
            return InternalNode(left, right, prediction, stats, split)
        else:  # LeafNode
            return LeafNode(prediction, stats)

    return jnode_to_python(jtree.rootNode())
which can be applied to a RandomForestModel like this:
nodes = [jtree_to_python(t) for t in rf_model._java_obj.trees()]
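The question's goal is to export the model to an environment without Spark, so one option is to dump the nested structure returned by `jtree_to_python` to JSON. This is a minimal sketch. The helpers `tree_to_dict` and `tree_from_dict` and the JSON layout are hypothetical choices, not a Spark format. The namedtuples are repeated so the block runs on its own.

```python
import json
from collections import namedtuple

# Same shapes as the namedtuples defined above, repeated so this sketch is self-contained.
LeafNode = namedtuple("LeafNode", ("prediction", "impurity"))
InternalNode = namedtuple(
    "InternalNode", ("left", "right", "prediction", "impurity", "split"))
CategoricalSplit = namedtuple("CategoricalSplit", ("feature_index", "categories"))
ContinuousSplit = namedtuple("ContinuousSplit", ("feature_index", "threshold"))


def tree_to_dict(node):
    """Convert a LeafNode/InternalNode tree into JSON-friendly nested dicts."""
    base = {"prediction": float(node.prediction),
            "impurity": [float(x) for x in node.impurity]}  # works for lists and numpy arrays
    if isinstance(node, LeafNode):
        return base
    split = node.split
    if isinstance(split, CategoricalSplit):
        base["split"] = {"type": "categorical", "feature_index": int(split.feature_index),
                         "categories": [float(c) for c in split.categories]}
    else:
        base["split"] = {"type": "continuous", "feature_index": int(split.feature_index),
                         "threshold": float(split.threshold)}
    base["left"] = tree_to_dict(node.left)
    base["right"] = tree_to_dict(node.right)
    return base


def tree_from_dict(d):
    """Inverse of tree_to_dict; impurity stats come back as plain lists."""
    if "split" not in d:
        return LeafNode(d["prediction"], d["impurity"])
    s = d["split"]
    if s["type"] == "categorical":
        split = CategoricalSplit(s["feature_index"], s["categories"])
    else:
        split = ContinuousSplit(s["feature_index"], s["threshold"])
    return InternalNode(tree_from_dict(d["left"]), tree_from_dict(d["right"]),
                        d["prediction"], d["impurity"], split)


# A tiny hand-built tree with made-up statistics, round-tripped through JSON.
tree = InternalNode(LeafNode(0.0, [8.0, 2.0]), LeafNode(1.0, [1.0, 9.0]),
                    1.0, [9.0, 11.0], ContinuousSplit(3, 0.5))
restored = tree_from_dict(json.loads(json.dumps(tree_to_dict(tree))))
print(restored == tree)  # True
```

The restored impurity stats are plain lists. Wrap them with `np.array` before passing the trees to `predict_probability` further down, because that function calls `.sum()` on them.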
Furthermore, such a structure can easily be used to make predictions, both for individual trees (warning: Python 3.7+ ahead; for legacy usage please refer to the functools documentation):
from functools import singledispatch

@singledispatch
def should_go_left(split, vector): pass

@should_go_left.register
def _(split: CategoricalSplit, vector):
    return vector[split.feature_index] in split.categories

@should_go_left.register
def _(split: ContinuousSplit, vector):
    return vector[split.feature_index] <= split.threshold

@singledispatch
def predict(node, vector): pass

@predict.register
def _(node: LeafNode, vector):
    return node.prediction, node.impurity

@predict.register
def _(node: InternalNode, vector):
    return predict(
        node.left if should_go_left(node.split, vector) else node.right,
        vector
    )
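On interpreters older than 3.7, `register` cannot read the type from annotations. The functools documentation instead describes passing the type explicitly, as in `@func.register(SomeType)`. Below is a sketch of the same dispatch written that way. The raising fallbacks are an addition of this sketch; the code above uses `pass`.

```python
from collections import namedtuple
from functools import singledispatch

LeafNode = namedtuple("LeafNode", ("prediction", "impurity"))
InternalNode = namedtuple(
    "InternalNode", ("left", "right", "prediction", "impurity", "split"))
CategoricalSplit = namedtuple("CategoricalSplit", ("feature_index", "categories"))
ContinuousSplit = namedtuple("ContinuousSplit", ("feature_index", "threshold"))


@singledispatch
def should_go_left(split, vector):
    raise TypeError("unsupported split type: {}".format(type(split)))


@should_go_left.register(CategoricalSplit)  # type passed explicitly, no annotations
def _(split, vector):
    return vector[split.feature_index] in split.categories


@should_go_left.register(ContinuousSplit)
def _(split, vector):
    return vector[split.feature_index] <= split.threshold


@singledispatch
def predict(node, vector):
    raise TypeError("unsupported node type: {}".format(type(node)))


@predict.register(LeafNode)
def _(node, vector):
    return node.prediction, node.impurity


@predict.register(InternalNode)
def _(node, vector):
    child = node.left if should_go_left(node.split, vector) else node.right
    return predict(child, vector)


# Hypothetical one-split tree: feature 0 <= 0.5 goes left.
tree = InternalNode(LeafNode(0.0, [3.0, 1.0]), LeafNode(1.0, [1.0, 4.0]),
                    1.0, [4.0, 5.0], ContinuousSplit(0, 0.5))
print(predict(tree, [0.7]))  # (1.0, [1.0, 4.0])
```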
and forests:
from typing import Iterable, Union

def predict_probability(nodes: Iterable[Union[InternalNode, LeafNode]], vector):
    total = np.array([
        v / v.sum() for _, v in (
            predict(node, vector) for node in nodes
        )
    ]).sum(axis=0)
    return total / total.sum()
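The question observes that averaging per-tree class predictions gives a wrong result, and a small worked example shows why. The sketch uses hypothetical leaf counts for a three-tree forest. `forest_probability` follows the same idea as `predict_probability` above, written in plain Python. `hard_vote_average` is an illustrative helper that averages the 0/1 labels a toDebugString dump would show; it is not Spark code.

```python
# Hypothetical leaf statistics reached by one input vector in a three-tree forest.
# Each entry is [count of class 0, count of class 1] at the reached leaf.
leaf_stats = [
    [4.0, 6.0],  # tree 1 would print "Predict: 1.0" (60% class 1)
    [4.0, 6.0],  # tree 2 would print "Predict: 1.0" (60% class 1)
    [9.0, 1.0],  # tree 3 would print "Predict: 0.0" (90% class 0)
]


def forest_probability(stats_per_tree):
    """Sum each tree's normalized leaf counts, then renormalize."""
    n_classes = len(stats_per_tree[0])
    total = [0.0] * n_classes
    for stats in stats_per_tree:
        s = sum(stats)
        for c in range(n_classes):
            total[c] += stats[c] / s
    norm = sum(total)
    return [t / norm for t in total]


def hard_vote_average(stats_per_tree):
    """Average the single class each tree predicts (its majority class)."""
    n_classes = len(stats_per_tree[0])
    votes = [0.0] * n_classes
    for stats in stats_per_tree:
        votes[stats.index(max(stats))] += 1.0
    return [v / len(stats_per_tree) for v in votes]


print([round(p, 4) for p in forest_probability(leaf_stats)])  # [0.5667, 0.4333]
print([round(p, 4) for p in hard_vote_average(leaf_stats)])   # [0.3333, 0.6667]
```

With these numbers the two approaches even disagree on the winning class. The confident third tree outweighs two lukewarm votes only when the leaf proportions are kept.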
That, however, depends on the internal API (and on the weakness of Scala package-scoped access modifiers) and might break in the future.
* The DataFrame loaded from the data path can easily be transformed into a structure compatible with the predict and predict_probability functions defined above:
from pyspark.sql.dataframe import DataFrame
from itertools import groupby
from operator import itemgetter

def model_data_to_tree(tree_data: DataFrame):
    def dict_to_tree(node_id, nodes):
        node = nodes[node_id]
        prediction = node.prediction
        impurity = np.array(node.impurityStats)

        # Leaves are stored with leftChild == rightChild == -1
        if node.leftChild == -1 and node.rightChild == -1:
            return LeafNode(prediction, impurity)
        else:
            left = dict_to_tree(node.leftChild, nodes)
            right = dict_to_tree(node.rightChild, nodes)
            feature_index = node.split.featureIndex
            left_value = node.split.leftCategoriesOrThreshold

            # numCategories == -1 marks a continuous split with a single threshold
            split = (
                CategoricalSplit(feature_index, left_value)
                if node.split.numCategories != -1
                else ContinuousSplit(feature_index, left_value[0])
            )

            return InternalNode(left, right, prediction, impurity, split)

    tree_id = itemgetter("treeID")
    rows = tree_data.collect()
    # Ensemble data has a treeID column with nested nodeData; a single tree's
    # data has the node fields at the top level. The root node has id 0.
    return ([
        dict_to_tree(0, {node.nodeData.id: node.nodeData for node in nodes})
        for tree, nodes in groupby(sorted(rows, key=tree_id), key=tree_id)
    ] if "treeID" in tree_data.columns
    else [dict_to_tree(0, {node.id: node for node in rows})])