Getting the leaf probabilities of a tree model in Spark


Question


I'm trying to refactor a trained Spark tree-based model (RandomForest or GBT classifier) so that it can be exported to environments without Spark. The toDebugString method is a good starting point. However, in the case of RandomForestClassifier, the string only shows the predicted class for each tree, without the relative probabilities. So, if you average the predictions across all the trees, you get a wrong result.


An example. We have a DecisionTree represented in this way:

DecisionTreeClassificationModel (uid=dtc_884dc2111789) of depth 2 with 5 nodes
  If (feature 21 in {1.0})
   Predict: 0.0
  Else (feature 21 not in {1.0})
   If (feature 10 in {0.0})
    Predict: 0.0
   Else (feature 10 not in {0.0})
    Predict: 1.0


As we can see, following the nodes, the prediction is always either 0 or 1. However, if I apply this single tree to a vector of features, I get probabilities like [0.1007, 0.8993], and they make perfect sense, since the proportion of negative/positive training examples that end up in the same leaf as the example vector matches the output probabilities.
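A quick sketch of where numbers like these come from (the counts below are hypothetical): each leaf keeps the class counts of the training examples routed to it, and normalizing those counts gives the probabilities.

import numpy as np

# Hypothetical class counts accumulated in one leaf during training
leaf_counts = np.array([1007.0, 8993.0])  # [class 0, class 1]
print(leaf_counts / leaf_counts.sum())    # [0.1007 0.8993]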


My questions: where are these probabilities stored? Is there a way to extract them? If so, how? A pyspark solution would be preferable.

Answer


I'm trying to refactor a trained Spark tree-based model (RandomForest or GBT classifiers) in such a way it can be exported in environments without Spark.


Given the growing number of tools designed for real-time serving of Spark (and other) models, that's probably reinventing the wheel.


However, if you want to access the model internals from plain Python, it is best to load its serialized form.

Let's say you have:

from pyspark.ml.classification import RandomForestClassificationModel

rf_model: RandomForestClassificationModel
path: str  # Absolute path

and you save the model:

rf_model.write().save(path)
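For a fully self-contained run, here is a minimal sketch of producing such a model; the toy dataset, column names, and parameters are all hypothetical:

from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.linalg import Vectors

# Hypothetical toy data; any labeled DataFrame with a vector column works
train = spark.createDataFrame(
    [(0.0, Vectors.dense([0.0, 1.0])), (1.0, Vectors.dense([1.0, 0.0]))],
    ["label", "features"])

rf_model = RandomForestClassifier(numTrees=3).fit(train)
path = "/tmp/rf_model"  # hypothetical location
rf_model.write().save(path)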


You can load it back using a Parquet reader that supports mixed struct and list types. The model writer writes both the node data:

node_data = spark.read.parquet("{}/data".format(path))

node_data.printSchema()

root
 |-- treeID: integer (nullable = true)
 |-- nodeData: struct (nullable = true)
 |    |-- id: integer (nullable = true)
 |    |-- prediction: double (nullable = true)
 |    |-- impurity: double (nullable = true)
 |    |-- impurityStats: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- rawCount: long (nullable = true)
 |    |-- gain: double (nullable = true)
 |    |-- leftChild: integer (nullable = true)
 |    |-- rightChild: integer (nullable = true)
 |    |-- split: struct (nullable = true)
 |    |    |-- featureIndex: integer (nullable = true)
 |    |    |-- leftCategoriesOrThreshold: array (nullable = true)
 |    |    |    |-- element: double (containsNull = true)
 |    |    |-- numCategories: integer (nullable = true)

and the tree metadata:

tree_meta = spark.read.parquet("{}/treesMetadata".format(path))

tree_meta.printSchema()                            
root
 |-- treeID: integer (nullable = true)
 |-- metadata: string (nullable = true)
 |-- weights: double (nullable = true)


where the former provides all the information you need, since the prediction process is basically an aggregation of impurityStats*.
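For example, a sketch of reading per-leaf class probabilities straight out of node_data (leaves are the nodes whose child pointers are -1, the same convention the reconstruction code at the end of this answer relies on):

import numpy as np

leaves = node_data.filter(
    "nodeData.leftChild = -1 AND nodeData.rightChild = -1")

for row in leaves.select(
        "treeID", "nodeData.id", "nodeData.impurityStats").collect():
    stats = np.array(row.impurityStats)
    print(row.treeID, row.id, stats / stats.sum())  # leaf class probabilities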


You could also access this data directly using the underlying Java objects:

from collections import namedtuple
import numpy as np

LeafNode = namedtuple("LeafNode", ("prediction", "impurity"))
InternalNode = namedtuple(
    "InternalNode", ("left", "right", "prediction", "impurity", "split"))
CategoricalSplit = namedtuple("CategoricalSplit", ("feature_index", "categories"))
ContinuousSplit = namedtuple("ContinuousSplit", ("feature_index", "threshold"))

def jtree_to_python(jtree):
    def jsplit_to_python(jsplit):
        if jsplit.getClass().toString().endswith(".ContinuousSplit"):
            return ContinuousSplit(jsplit.featureIndex(), jsplit.threshold())
        else:
            jcat = jsplit.toOld().categories()
            return CategoricalSplit(
                jsplit.featureIndex(),
                [jcat.apply(i) for i in range(jcat.length())])

    def jnode_to_python(jnode):
        prediction = jnode.prediction()        
        stats = np.array(list(jnode.impurityStats().stats()))

        if jnode.numDescendants() != 0:  # InternalNode
            left = jnode_to_python(jnode.leftChild())
            right = jnode_to_python(jnode.rightChild())
            split = jsplit_to_python(jnode.split())

            return InternalNode(left, right, prediction, stats, split)            

        else:
            return LeafNode(prediction, stats) 

    return jnode_to_python(jtree.rootNode())

which can be applied to a RandomForestModel like this:

nodes = [jtree_to_python(t) for t in rf_model._java_obj.trees()]
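As a quick sanity check, you can inspect the parsed trees; the exact output of course depends on your model:

root = nodes[0]
print(root.split)     # e.g. CategoricalSplit(feature_index=21, categories=[1.0])
print(root.impurity)  # class counts seen at this node during training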


Furthermore, such a structure can easily be used to make predictions, both for individual trees (warning: Python 3.7+ ahead; for legacy usage please refer to the functools documentation):

from functools import singledispatch

@singledispatch
def should_go_left(split, vector): pass

@should_go_left.register
def _(split: CategoricalSplit, vector):
    return vector[split.feature_index] in split.categories

@should_go_left.register
def _(split: ContinuousSplit, vector):
    return vector[split.feature_index] <= split.threshold

@singledispatch
def predict(node, vector): pass

@predict.register
def _(node: LeafNode, vector):
    return node.prediction, node.impurity

@predict.register
def _(node: InternalNode, vector):
    return predict(
        node.left if should_go_left(node.split, vector) else node.right,
        vector
    )

and forests:

from typing import Iterable, Union

def predict_probability(nodes: Iterable[Union[InternalNode, LeafNode]], vector):
    total = np.array([
        v / v.sum() for _, v in  (
            predict(node, vector) for node in nodes
        )
    ]).sum(axis=0)
    return total / total.sum()
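A usage sketch; the feature vector, its dimension, and (for the cross-check) the "features" column name are hypothetical:

vector = np.zeros(30)  # hypothetical dense feature vector
vector[21] = 1.0

single_tree = predict(nodes[0], vector)            # (prediction, impurity stats)
forest_probs = predict_probability(nodes, vector)  # forest-level probabilities

# Optional cross-check against Spark's own output, assuming the model
# was trained on a column named "features"
from pyspark.ml.linalg import Vectors

df = spark.createDataFrame([(Vectors.dense(vector),)], ["features"])
print(rf_model.transform(df).head().probability)
print(forest_probs)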


That, however, depends on the internal API (and the weakness of Scala package-scoped access modifiers) and might break in the future.


* The DataFrame loaded from the data path can easily be transformed into a structure compatible with the predict and predict_probability functions defined above:

from pyspark.sql.dataframe import DataFrame 
from itertools import groupby
from operator import itemgetter


def model_data_to_tree(tree_data: DataFrame):
    def dict_to_tree(node_id, nodes):
        node = nodes[node_id]
        prediction = node.prediction
        impurity = np.array(node.impurityStats)

        if node.leftChild == -1 and node.rightChild == -1:
            return LeafNode(prediction, impurity)
        else:
            left = dict_to_tree(node.leftChild, nodes)
            right = dict_to_tree(node.rightChild, nodes)
            feature_index = node.split.featureIndex
            left_value = node.split.leftCategoriesOrThreshold

            split = (
                CategoricalSplit(feature_index, left_value)
                if node.split.numCategories != -1
                else ContinuousSplit(feature_index, left_value[0])
            )

            return InternalNode(left, right, prediction, impurity, split)

    tree_id = itemgetter("treeID")
    rows = tree_data.collect()
    return ([
        dict_to_tree(0, {node.nodeData.id: node.nodeData for node in nodes})
        for tree, nodes in groupby(sorted(rows, key=tree_id), key=tree_id)
    ] if "treeID" in tree_data.columns
    else [dict_to_tree(0, {node.id: node for node in rows})])
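With the same hypothetical vector as above, the Parquet route then reduces to:

nodes = model_data_to_tree(node_data)
print(predict_probability(nodes, vector))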
