在决策树中为每个数据点找到对应的叶节点(scikit-learn) [英] Finding a corresponding leaf node for each data point in a decision tree (scikit-learn)

查看:52
本文介绍了在决策树中为每个数据点找到对应的叶节点(scikit-learn)的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用python 3.4中scikit-learn包中的决策树分类器,我想为每个输入数据点获取对应的叶节点ID.

I'm using decision tree classifier from the scikit-learn package in python 3.4, and I want to get the corresponding leaf node id for each of my input data point.

例如,我的输入可能像这样:

For example, my input might look like this:

array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2]])

,并假设相应的叶节点分别为16、5和45.我希望输出为:

and let's suppose the corresponding leaf nodes are 16, 5 and 45 respectively. I want my output to be:

leaf_node_id = array([16, 5, 45])

我已经阅读了scikit-learn邮件列表和有关SF的相关问题,但仍然无法正常使用.这是我在邮件列表中找到的一些提示,但仍然无法使用.

I have read through the scikit-learn mailing list and related questions on SF but I still can't get it to work. Here is some hint I found on the mailing list, but still does not work.

http://sourceforge.net/p/scikit-learn/mailman/message/31728624/

在一天结束时,我只想拥有一个函数Ge​​tLeafNode(clf,X_valida),使其输出为相应叶节点的列表.下面是重现我收到的错误的代码.因此,任何建议将不胜感激.

At the end of the day, I just want to have a function GetLeafNode(clf, X_valida) such that its output is a list of corresponding leaf nodes. Below is the code that reproduces the error I received. So, any suggestion will be very appreciated.

from sklearn.datasets import load_iris
from sklearn import tree

# load data and divide it to train and validation
iris = load_iris()

num_train = 100
X_train = iris.data[:num_train,:]
X_valida = iris.data[num_train:,:]

y_train = iris.target[:num_train]
y_valida = iris.target[num_train:]

# fit the decision tree using the train data set
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, y_train)

# Now I want to know the corresponding leaf node id for each of my training data point
clf.tree_.apply(X_train)

# This gives the error message below:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-17-2ecc95213752> in <module>()
----> 1 clf.tree_.apply(X_train)

_tree.pyx in sklearn.tree._tree.Tree.apply (sklearn/tree/_tree.c:19595)()

ValueError: Buffer dtype mismatch, expected 'DTYPE_t' but got 'double'

推荐答案

我终于使它起作用了.这是基于我在scikit-learn中的消息的解决方案邮件列表:

I finally got it to work. Here is one solution based on my correspondence message in the scikit-learn mailing list:

在scikit-learn版本0.16.1之后, clf.tree _ 中实现了apply方法,因此,我遵循以下步骤:

After scikit-learn version 0.16.1, apply method is implemented in clf.tree_, therefore, I followed the following steps:

  1. 将scikit-learn更新到最新版本(0.16.1),以便您可以使用 clf.tree _
  2. 中的 apply 方法
  3. 使用以下命令将输入​​数据数组( X_train X_valida )从 float64 转换为 float32 .X_train = X_train.astype('float32')
  4. 现在,您可以通过以下方式使用 apply 方法: clf.tree_.apply(X_train),您将获得每个数据点的叶子节点ID.
  1. update scikit-learn to the latest version (0.16.1) so that you can use apply method from clf.tree_
  2. convert the input data arrays (X_train, X_valida) from float64 to float32 using: X_train = X_train.astype('float32')
  3. Now you can use apply method in this way: clf.tree_.apply(X_train) and you will get the leaf node id for each data point.

这是最终代码:

from sklearn.datasets import load_iris
from sklearn import tree

# load data and divide it to train and validation
iris = load_iris()

num_train = 100
X_train = iris.data[:num_train,:]
X_valida = iris.data[num_train:,:]

y_train = iris.target[:num_train]
y_valida = iris.target[num_train:]

# convert data to float32
X_train = X_train.astype('float32')

# fit the decision tree using the train data set
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, y_train)

# Now I want to know the corresponding leaf node id for each of my training data point
clf.tree_.apply(X_train)

# This gives the leaf node id:
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2])

这篇关于在决策树中为每个数据点找到对应的叶节点(scikit-learn)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆