首页 > 解决方案 > 当 max_depth 为 1 时,sklearn DecsionTreeClassifier 如何选择输出值?

问题描述

这是我的代码

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

dataset = load_iris()
X_train,X_test,y_train,y_test = train_test_split(dataset.data,dataset.target,test_size=0.3)


reg = DecisionTreeClassifier(max_depth=1)
reg.fit(X_train,y_train)
print(reg.predict(X_test))

在此处输入图像描述

我已经为训练集添加了树的图像,在这里你可以看到在错误的情况下,数据集的值 [0,39,38]分别指向 0、1、2 的输出。因此,从假数据集 1 成为输出的可能性最高。决策树应该根据树对 0 或 1 进行分类,但我也可以在预测中看到 2。那么,sklearn 如何在什么条件下选择假集上的类来预测输出。

标签: pythonmachine-learningscikit-learndata-science

解决方案


我敢肯定,差异可能是因为没有设置random_state.

这里有两个随机性的地方,

  • 训练测试拆分
  • 构建决策树模型

您可能已经使用决策树进行了预测,然后使用另一个决策树创建了可视化。

尝试以下具有不同random_state值的代码:

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import plot_tree

dataset = load_iris()

X_train,X_test,y_train,y_test = train_test_split(dataset.data,
                                                 dataset.target,
                                                 test_size=0.3,
                                                 random_state=0)
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(max_depth=1, random_state=1)
clf.fit(X_train,y_train)
print(clf.predict(X_test))

plot_tree(clf)

在此处输入图像描述

注意:您需要 sklearn 版本 0.21.2 才能获得plot_tree功能。


推荐阅读