python - 获取 DecisionTreeRegressor 中终端(叶)节点的所有值
问题描述
决策树将节点拆分,直到出现一些破坏条件,并使用任何节点中值的平均值作为预测。
我想获得这样一个节点中的所有值,而不仅仅是平均值,然后执行更复杂的操作。我正在使用sklearn。我没有找到任何答案,只是一种使用DecisionTreeRegressor.tree_.value
.
怎么做?
解决方案
AFAIK 对此没有任何 API 方法,但您当然可以通过编程方式获取它们。
让我们制作一些虚拟数据并首先构建一个回归树来证明这一点:
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_graphviz
# dummy data
rng = np.random.RandomState(1) # for reproducibility
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))
estimator = DecisionTreeRegressor(max_depth=3)
estimator.fit(X, y)
import graphviz
dot_data = export_graphviz(estimator, out_file=None)
graph = graphviz.Source(dot_data)
graph
这是我们的决策树图:
从中可以明显看出我们有 8 个叶子,其中描述了样本的数量和每个叶子的平均值。
这里的关键命令是apply
:
on_leaf = estimator.apply(X)
on_leaf
# result:
array([ 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6,
6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14])
on_leaf
长度等于我们的数据X
和结果y
;它给出了每个样本结束的节点的索引(所有节点on_leaf
都是终端节点,即叶子)。它的唯一值的数量等于我们的叶子的数量,这里是 8:
len(np.unique(on_leaf))
# 8
并on_leaf[k]
给出y[k]
结束的节点数。
现在我们可以得到y
8 片叶子中每一片叶子的实际值:
leaves = []
for i in np.unique(on_leaf):
leaves.append(y[np.argwhere(on_leaf==i)])
len(leaves)
# 8
让我们验证一下,根据我们的图,第一片叶子只有一个样本,其值为-1.149
(因为它是单样本叶子,所以样本的值等于均值):
leaves[0]
# array([[-1.1493464]])
看起来挺好的。那么第二片叶子呢,有 10 个样本,平均值为-0.173
?
leaves[1]
# result:
array([[ 0.09131401],
[ 0.09668352],
[ 0.13651039],
[ 0.19403525],
[-0.12383814],
[ 0.26365828],
[ 0.41252216],
[ 0.44546446],
[ 0.47215529],
[-0.26319138]])
len(leaves[1])
# 10
leaves[1].mean()
# 0.17253138570808904
依此类推 - 最后检查最后一片叶子(#7),有 4 个样本和平均值-0.99
:
leaves[7]
# result:
array([[-0.99994398],
[-0.99703245],
[-0.99170146],
[-0.9732277 ]])
leaves[7].mean()
# -0.9904763973694366
总结一下:
使用 data X
、y
results 和决策树回归器estimator
需要的是:
on_leaf = estimator.apply(X)
leaves = []
for i in np.unique(on_leaf):
leaves.append(y[np.argwhere(on_leaf==i)])
推荐阅读
- azure - Azure IoT 中心框架 3.5
- c++ - 使用 C++ 标准库以对数时间进行 Heapify
- algorithm - 在 O(n^2 * log n) 中 O(n^2) 来自哪里?
- java - 每个路径 Spring 多个 HandlerMethodArgumentResolver
- centos7 - sshd 执行新连接而不是 fork
- kubernetes - 为什么 dig 不通过 dns 名称解析 K8s 服务,而 nslookup 没有问题?
- mysql - 定期将数据从 AWS RDS (MySQL) 复制到另一台服务器(EC2 实例)
- c# - 在视图中检查模型是否有数据
- c# - 引用一个类/方法
- php - 如何从 JSON 数据中获取值作为数组