machine-learning - 最大递归深度错误(决策树代码)
问题描述
我为决策树编写了以下代码-
import csv
from google.colab import files
import numpy as np
class DecisionTree():
def entropy(self,set):#entropy calculation
p0=p1=0
for r in set:
if float(r[-1])==0:
p0=p0+1
elif float(r[-1])==1:
p1=p1+1
if p0!=0 and p1!=0:
return -((p0/(p0+p1))*np.log2(p0/(p0+p1))+(p1/(p0+p1))*np.log2(p1/(p0+p1)))
else:
return 0
def info_gain(self,parent,c1,c2):#information gain calculation
return self.entropy(parent)-(self.entropy(c1))*len(c1)/len(parent)-(self.entropy(c2))*len(c2)/len(parent)
def split_left(self,thresh,index,data_set):#left child
left=[]
for parent in data_set:
if float(parent[index])<float(thresh):
left.append(parent)
return left
def split_right(self,thresh,index,data_set):#right child
right=[]
for parent in data_set:
if float(parent[index])>=float(thresh):
right.append(parent)
return right
def best_split(self,training_set):
max_info_gain=-(np.inf)
corres_ind=0
thresh_max=0
for i in range(0,len(training_set[0])-1):
thresh=[]
for j in range(0,len(training_set)):
thresh.append(float(training_set[j][i]))
features=np.linspace(min(thresh),max(thresh),10)
for l in range(0,len(features)):
left=self.split_left(features[l],i,training_set)
right=self.split_right(features[l],i,training_set)
info_gain=self.info_gain(training_set,left,right)
if float(info_gain)>float(max_info_gain):
thresh_max=features[l]
corres_ind=i
max_info_gain=info_gain
return [thresh_max,corres_ind]
tree = {}
def learn(self, training_set):
# implement this function
self.tree = {}
if self.entropy(training_set)!=0 and len(training_set)>=2:
self.tree[0]=(self.best_split(training_set))[0]
self.tree[1]=(self.best_split(training_set))[1]
thre=self.tree[0]
ind=self.tree[1]
left_tree=DecisionTree()
right_tree=DecisionTree()
left=self.split_left(thre,ind,training_set)
right=self.split_right(thre,ind,training_set)
left_tree.learn(left)
right_tree.learn(right)
self.tree[2]=left_tree.tree
self.tree[3]=right_tree.tree
self.tree[4]=None
else:
self.tree[3]=None
self.tree[2]=None
self.tree[1]=None
self.tree[0]=None
c0=0
c1=0
for instance in training_set:
#print(instance)
if float(instance[-1])==0:
c0=c0+1
elif float(instance[-1])==1:
c1=c1+1
#print(c0,c1)
if c0>=c1:
self.tree[4]=0
elif c1>c0:
self.tree[4]=1
return self.tree
# implement this function
def classify(self, test_instance,x):
#result = 0 # baseline: always classifies as 0
if (x[4]!=None):
#print('called')
return x[4]
else:
value_rec=test_instance[x[1]]
if float(value_rec)<float(x[0]):
#print(float(value_rec),float(x['thresh_val']))
#print(x['left'])
return self.classify(test_instance,x[2])
elif float(value_rec)>=float(x[0]):
return self.classify(test_instance,x[3])
#return result
def run_decision_tree():
f=files.upload()
# Load data set
with open("spa.csv") as f:
next(f, None)
data = [tuple(line) for line in csv.reader(f, delimiter=",")]
print("Number of records: %d" % len(data))
# Split training/test sets
# You need to modify the following code for cross validation.
accuracies=[]
for K in range(0,10):
training_set = [x for i, x in enumerate(data) if i % 10 != K]
test_set = [x for i, x in enumerate(data) if i % 10== K]
tree = DecisionTree()
# Construct a tree using training set
dic=tree.learn( training_set )
print(dic)
#print(dic['left']['left']['left']['value'])
# Classify the test set using the tree we just constructed
results = []
for instance in test_set:
result = tree.classify( instance,dic )
#print(result)
results.append( result == float(instance[-1]))
# Accuracy
accuracy = float(results.count(True))/float(len(results))
print("accuracy: %.4f" % accuracy)
accuracies.append(accuracy)
if __name__ == "__main__":
run_decision_tree()
这适用于我的一个数据集。对于另一个数据集,它适用于大约 1/5 的记录,但是当我尝试使用整个数据集时,它说超出了最大递归深度。任何关于我哪里出错的指示都将受到高度赞赏。谢谢!
解决方案
推荐阅读
- javascript - 在 JavaScript 中插入 MySQLi
- ios - 如何从 NativeScript 调用 Objective-C NSExpression(format: ....)?
- c# - 关闭主窗体时最大化
- performance - Vugen 运行时设置不完整
- ios - 如何在代码中的任何地方使用 print() 会以某种方式强制加载 libswiftSwiftOnoneSupport.dylib
- react-native - 将 DateTimePicker 与第一次使用的时间反应
- elasticsearch - 由于多个实例,Logstash 无法启动,即使没有运行它的实例
- c - 如何使用C中的结构和递归找到数组中的最大值和最小值
- javascript - Apollo React 客户端 - 查询组件不会在第一次道具/变量更改时到达后端 - 在所有后续更改中都会执行
- javascript - 如何删除Javascript中的所有对象引用?