首页 > 解决方案 > 在py中打印决策树

问题描述

我有一个class Node

class Node:
def __init__(self, label):
    self.label = label
    self.children = []
    self.answer = ''
    
    self.isPruned = False

其中label设置为属性eg( Outlook)。

children(value, Node)是一个数组,我在其中存储一对value分割决策,并且Node下一个节点 answer包含路径末尾的分类

我想将它打印为一棵树,在每个边缘上都会有value分裂的。

def print_tree(node, level = 0):
    if node.answer != '':
        print(' '*level, node.answer)
        return
    print(''*level, node.label)
    for value, n in node.children:
        print(''*(level+1), value)
        print_tree(n, level+2)
  

这就是我想出的,但不是我所期望的。

关于如何轻松显示它的任何想法,在正确的位置显示每个节点和边缘?

在此处输入图像描述

编辑:

这是生成树的代码:

def build_tree(df, attributes, target, parent = None):
    
    if df.empty:
        node = Node('')
        node.is_leaf = True
        node.answer = plurality_value(df[target])
        return node
    
    elif len(np.unique(df[target])) <= 1: #Tutti gli esempi sono uguali ritorna quella classificazione
        node = Node('')
        node.is_leaf = True
        node.answer = np.unique(df[target])[0]
        return node
    
    elif len(attributes) == 0:
        node = Node('')
        node.is_leaf = True
        node.answer = plurality_value(parent)
        return node
    
    else:
        best_split = importance(df, target)
        tree = Node(best_split)
        tree.is_internal = True
        
        new_attributes = [i for i in attributes if i != best_split]
        new_parent = df[target]
        
        for value in np.unique(df[best_split]):
            min_df = df.where(df[best_split] == value).dropna()
            sub_tree = build_tree(min_df, new_attributes, target, new_parent)
            tree.children.append((value, sub_tree))

        return tree

标签: pythongraphtreebinary-search-treegraphviz

解决方案


推荐阅读