首页 > 解决方案 > 使用 Scala 模式匹配实现 BST

问题描述

我是 scala 的新手,并试图在 scala 中使用模式匹配概念来实现 BST。

编辑:我已经修改了插入函数,现在它的行为符合预期,有人可以帮我让它尾递归吗?此外,任何其他代码改进将不胜感激。

trait IntTree {                                                                                                                                   
  def contains(v: Int): Boolean                                                                                                                   

  def insert(x: Int): IntTree                                                                                                                     
}                                                                                                                                                 

case object EmptyTree extends IntTree {                                                                                                           

  override def insert(x: Int): IntTree = Node(x, EmptyTree, EmptyTree)                                                                            

  override def contains(v: Int): Boolean = false                                                                                                  
}                                                                                                                                                 

case class Node(elem: Int, left: IntTree, right: IntTree) extends IntTree {                                                                       

  override def contains(v: Int): Boolean = {                                                                                                      
    @scala.annotation.tailrec                                                                                                                     
    def contains(t: IntTree, v: Int): Boolean = t match {                                                                                         
      case Node(data, _, _) if (data == v) => true                                                                                                
      case Node(data, l, r) => if (data > v) contains(l, v) else contains(r, v)                                                                   
      case _ => false                                                                                                                             
    }                                                                                                                                             

    contains(this, v)                                                                                                                             
  }                                                                                                                                               

  override def insert(x: Int): IntTree = {                                                                                                        
    def insert(t: IntTree, x: Int): IntTree = t match {                                                                                           
      case Node(data, l, r) if (data > x) => Node(data, insert(l, x), r)                                                                          
      case Node(data, l, r) if (data < x) => Node(data, l, insert(r, x))                                                                          
      case EmptyTree => t insert x                                                                                                                
      case _ => t                                                                                                                                 
    }                                                                                                                                             

    insert(this, x)                                                                                                                               
  }                                                                                                                                               
}                                      

标签: scala

解决方案


在您走下叶子之后,它需要重新访问和更新父节点:

sealed trait IntTree {
    def contains(v: Int): Boolean
    def insert(x: Int): Node // better to return Node here
}

def insert(x: Int): Node = {
    @annotation.tailrec
    def insert(t: IntTree, x: Int, parents: List[Node]): Node = t match {
        case EmptyTree =>
            parents.foldLeft(t insert x) { case (n, p) =>
                if (p.elem >= n.elem) p.copy(left = n)
                else p.copy(right = n)
            }
        case Node(data, l, r) =>
            insert(if(data >= x) l else r, x, t :: parents)
    }

    insert(this, x, List.empty)
}

推荐阅读