首页 > 技术文章 > 随机森林和决策树交叉验证的使用

oceaning 2021-05-23 22:02 原文


##导入包
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_wine

# Load scikit-learn's bundled wine dataset as a Bunch object; its `.data`
# (feature matrix) and `.target` (class labels) attributes are used below.
wine = load_wine(return_X_y=False)

from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score

# Hold out 30% of the samples for testing. The split is seeded just like the
# two models below: without a random_state every execution would compare the
# models on a different split, so the printed scores would not be reproducible.
Xtrain, Xtest, Ytrain, Ytest = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=0
)

# Single decision tree model.
clf = DecisionTreeClassifier(random_state=0)

# Random forest model.
rfc = RandomForestClassifier(random_state=0)

# Train both models on the training split (fit() returns the estimator itself).
clf = clf.fit(Xtrain, Ytrain)
rfc = rfc.fit(Xtrain, Ytrain)

# Score both models on the held-out test split.
score_c = clf.score(Xtest, Ytest)
score_r = rfc.score(Xtest, Ytest)

print("Single Tree:{}".format(score_c), "Random forest:{}".format(score_r))

# Cross-validation: split the dataset into n folds, use each fold in turn as
# the test set and the remaining n-1 folds as the training set. Training the
# model several times this way shows how stable its performance is.

n_runs = 10  # how many repeated cross-validation runs to compare and plot
rfc_1 = []   # mean 10-fold CV score of the random forest, one entry per run
clf_1 = []   # mean 10-fold CV score of the decision tree, one entry per run

for _ in range(n_runs):
    # The models are deliberately left unseeded (no random_state), so their
    # scores can differ from run to run. Distinct names keep the fitted
    # `rfc` / `clf` from the hold-out comparison from being overwritten.
    rfc_cv = RandomForestClassifier(n_estimators=25)
    # cv=10: 10-fold cross-validation. The full dataset is passed in because
    # cross_val_score does the train/test splitting itself.
    rfc_s = cross_val_score(rfc_cv, wine.data, wine.target, cv=10).mean()
    rfc_1.append(rfc_s)

    clf_cv = DecisionTreeClassifier()
    clf_s = cross_val_score(clf_cv, wine.data, wine.target, cv=10).mean()
    clf_1.append(clf_s)

# Plot one point per run (numbered from 1) for each model.
runs = range(1, n_runs + 1)
plt.plot(runs, rfc_1, label="Random Forest")
plt.plot(runs, clf_1, label="Decision Tree")
plt.xlabel("Run")
plt.ylabel("Mean 10-fold CV score")
plt.legend()
plt.show()

推荐阅读