首页 > 技术文章 > 决策树

lspis 2021-12-06 22:13 原文

关键代码:

生成决策树:

def createDataSet(path="pd1_table.csv", ratio=0.21, encoding=None):   # build the training data
    """Load the training CSV and binarise every row into a decision-tree sample.

    Each CSV row is converted column by column with ``col_type`` (columns 0
    and 4-6 stay strings, columns 1-3 become floats) and mapped to
    ``[f0, f1, f2, cls]``:

    * ``f0`` is ``"0"`` if ``row[1] < ratio * row[2]`` else ``"1"``;
    * ``f1`` is ``"0"`` if ``row[1] < row[3]`` else ``"1"``;
    * ``f2`` is ``"0"`` if ``row[4] > row[5]`` else ``"1"``;
    * ``cls`` is ``"no"`` if ``row[6] == "1"`` else ``"yes"``.

    The original comment on these flags reads "0 = wrong, 1 = right".

    Args:
        path: CSV file to read. Its first row is a header and is skipped.
        ratio: Multiplier applied to ``row[2]`` for the first feature.
        encoding: Passed to ``open``; ``None`` keeps the platform default.

    Returns:
        ``(dataSet, label)``. ``dataSet`` is a list of 4-element lists.
        ``label`` holds the three feature names, in the order of ``f0..f2``.
    """
    col_type = [str, float, float, float, str, str, str]
    dataSet = []
    # newline="" is what the csv module requires so that quoted fields
    # containing line breaks are parsed correctly.
    with open(path, newline="", encoding=encoding) as f:
        f_csv = csv.reader(f)
        next(f_csv, None)  # skip the header row; tolerate an empty file
        for raw in f_csv:
            if not raw:
                # csv.reader yields [] for blank lines; they hold no sample.
                continue
            row = tuple(convert(value) for convert, value in zip(col_type, raw))
            dataSet.append([
                # NOTE(review): getDataSet() encodes the test set with a ratio
                # of 100 here, so train and test features disagree — confirm.
                "0" if row[1] < ratio * row[2] else "1",
                "0" if row[1] < row[3] else "1",
                # Columns 4 and 5 stay strings, so this is a lexicographic
                # comparison. It is only meaningful if both share a sortable
                # format (presumably ISO dates, per the label) — verify.
                "0" if row[4] > row[5] else "1",
                "no" if row[6] == "1" else "yes",
            ])
    label = ['销税-购税', '销税-票税', '开业日期-开票日期']
    return dataSet, label

def calcShannonEnt(dataSet):  # Shannon entropy of the class column
    """Return the Shannon entropy (base 2) of the class labels in dataSet.

    The class label of a row is its last element. Each distinct label with
    relative frequency p contributes ``-p * log2(p)``. An empty data set
    yields 0.
    """
    total = len(dataSet)
    counts = {}
    for record in dataSet:
        cls = record[-1]
        counts[cls] = counts.get(cls, 0) + 1
    entropy = 0
    for occurrences in counts.values():
        p = float(occurrences) / total
        entropy -= p * log(p, 2)
    return entropy

def splitDataSet(dataSet, axis, value):   # rows whose feature `axis` equals `value`
    """Select the rows whose column ``axis`` equals ``value``.

    Every selected row is returned as a new list with column ``axis`` removed.
    The input rows themselves are not modified.
    """
    def _without_axis(record):
        # Copy the prefix, then append the suffix, skipping column `axis`.
        trimmed = record[:axis]
        trimmed.extend(record[axis + 1:])
        return trimmed

    return [_without_axis(record) for record in dataSet if record[axis] == value]

def chooseBestFeatureToSplit(dataSet):  # pick the feature to split on
    """Return the index of the feature with the largest information gain.

    The last element of each row is the class label; every earlier element is
    a feature. The information gain of feature ``i`` is the entropy of the
    whole set minus the size-weighted entropy of the subsets obtained by
    splitting on each distinct value of ``i``. Ties keep the lowest index.

    Returns:
        A feature index, or -1 only when the rows contain no feature columns.
    """
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)  # entropy before splitting
    total = float(len(dataSet))
    # Start below any reachable gain so that a feature is always picked. A
    # start of 0 made this return -1 whenever every gain was 0 (e.g. XOR-like
    # data), and createTree then split on the class column instead.
    bestInfoGain = float("-inf")
    bestFeature = -1
    for i in range(numFeatures):
        newEntropy = 0
        # dict.fromkeys keeps first-seen order, so the float sum (and thus
        # tie-breaking) does not depend on per-run string hash randomisation.
        for value in dict.fromkeys(example[i] for example in dataSet):
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / total
            newEntropy += prob * calcShannonEnt(subDataSet)  # entropy after splitting on i
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:  # strict '>' keeps the earlier feature on ties
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

def majorityCnt(classList):    # majority vote, e.g. 2 "male" vs 1 "female" -> "male"
    """Return the most frequent element of classList.

    Ties go to the label seen first in classList, because the descending sort
    is stable. An empty list raises IndexError.
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    ranked = sorted(tally, key=lambda label: tally[label], reverse=True)
    return ranked[0]

def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataSet: Rows of feature values followed by a class label (last item).
        labels: Feature names, parallel to the feature columns of dataSet.
            The list is copied internally; the caller's list is not modified.

    Returns:
        Either a class label (a leaf), or ``{feature_name: {value: subtree}}``.
        In the dict form, the feature is the one chosen by
        chooseBestFeatureToSplit.
    """
    classList = [example[-1] for example in dataSet]  # class label of each row
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # every row has the same class: leaf
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # no features left: majority vote
    bestFeat = chooseBestFeatureToSplit(dataSet)  # index of the feature to split on
    if bestFeat < 0:
        # No feature was selected. Index -1 would address the class column,
        # so stop here with a majority vote instead of splitting on the label.
        return majorityCnt(classList)
    labels = labels[:]  # copy so deleting below never touches the caller's list
    bestFeatLabel = labels[bestFeat]
    del labels[bestFeat]
    myTree = {bestFeatLabel: {}}  # the tree is stored as nested dicts
    # dict.fromkeys gives the distinct values in first-seen order.
    for value in dict.fromkeys(example[bestFeat] for example in dataSet):
        # Each recursive call copies `labels` itself, so passing it is safe.
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), labels)
    return myTree

def getDataSet():   # 创造测试数据
    col_type = [str, float, float, float, str, str, str];
    dataSet = []  # 数据源列表
    with open("pd_table.csv") as f:
        f_csv = csv.reader(f)

        # 获得headers
        headers = next(f_csv)



        # init the row from each row
        for row in f_csv:
            data = []
            row = tuple(convert(value) for convert, value in zip(col_type, row))
           #print(row)
            if (row[1] <100*row[2]):
                data.append('0')
            else:
                data.append('1')

            if (row[1] < row[3]):
                data.append('0')
            else:
                data.append('1')

            if (row[4] > row[5]):
                data.append('0')
            else:
                data.append('1')

            if(row[6]=="1"):
                data.append('no')
                # 0错  1对
                #dataSet.append(data)
            else:
                data.append('yes')
                # 0错  1对
                #dataSet.append(data)
            dataSet.append(data)
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    label = ["销税-购税", "销税-票税", "开业日期-开票日期"]

    # list=[]
    # for i in dataSet:
    #     dirc={}
    #     dirc["销税-购税"]=i[0]
    #     dirc["销税-票税"]=i[1]
    #     dirc["开业日期-开票日期"]=i[2]
    #     list.append(dirc)

    return dataSet

def adjust():
    """Train on the training CSV, classify every test row, and count outcomes.

    NOTE(review): the original body could not run. It rebound ``list`` to the
    test data, so ``list(tree.keys())`` raised TypeError. It indexed list rows
    with a string key. Its walk loop never terminated when a node's first
    branch key was not '1'. This version implements what its counters
    suggest: count test rows predicted 'yes' and 'no'. Confirm that this is
    the intended evaluation.

    Returns:
        ``(num1, num2)``: the number of test rows classified 'yes' and 'no'.
        Rows whose feature values have no branch in the tree are not counted.
    """
    num1 = 0  # rows predicted 'yes'
    num2 = 0  # rows predicted 'no'
    dataSet, label = createDataSet()
    featnames = label[:]  # untouched copy: createTree may delete entries from its labels argument
    tree = createTree(dataSet, label)
    for row in getDataSet():
        if isinstance(tree, dict):
            prediction = classify(tree, featnames, row)
        else:
            prediction = tree  # the training data contained a single class
        if prediction == 'yes':
            num1 += 1
        elif prediction == 'no':
            num2 += 1
    return num1, num2
对测试集进行分类:
def classify(Tree, featnames, X):
    """Predict the class of sample X by walking the decision tree.

    Args:
        Tree: Nested dict of the form ``{feature_name: {value: subtree}}``
            (as built by createTree), or a bare leaf label.
        featnames: Feature names in the same order as the values in X. Used to
            find which position of X each node tests.
        X: Feature values of one sample.

    Returns:
        The leaf label reached, or '' when X has a value with no matching
        branch in the tree.
    """
    if not isinstance(Tree, dict):
        return Tree  # a leaf, or a tree consisting of a single leaf
    root = list(Tree)[0]  # each internal node has one key: the tested feature
    branches = Tree[root]
    value = X[featnames.index(root)]  # this sample's value for the root feature
    if value not in branches:
        return ''  # value never seen in training at this node
    return classify(branches[value], featnames, X)

推荐阅读