首页 > 解决方案 > 如何对列表元素的类型进行模式匹配

问题描述

我想根据对象的类型对对象列表进行模式匹配。但是将模式指定为case x: List[ObjectType]似乎不起作用。

以这个程序为例。

sealed trait A
case class B() extends A
case class C() extends A

def func(theList: List[A]) = theList match
{
    case listOfB: List[B] => println("All B's")
    case listOfC: List[C] => println("All C's")
    case _ => println("Somthing else")
}

func(List(C(), C(), C())) // prints: "All B's"

虽然列表只包含 C 并且 case 模式指定了 B 的列表,但 match 语句将其识别为 B 的列表?

我知道我可以像这样检查列表的每个元素:

case listOfA: List[A] if listOfA.forall{case B() => true case _ => false} => println("All B's")

listOfA.asInstanceOf[List[B]]但它比较麻烦,当我尝试使用它时,我必须指定它确实是 B 的 ( ) 列表。

我怎样才能以更聪明/更好的方式做到这一点?

标签: scalapattern-matching

解决方案


尝试自定义提取器以减少模式匹配的繁琐

object AllB {
  def unapply(listOfA: List[A]): Boolean = 
    listOfA.forall { case B() => true; case _ => false }
}
object AllC {
  def unapply(listOfA: List[A]): Boolean = 
    listOfA.forall { case C() => true; case _ => false }
}

def func(theList: List[A]) = theList match {
  case AllB() => println("All B's")
  case AllC() => println("All C's")
  case _      => println("Something else")
}

func(List(B(), B(), B()))    // All B's
func(List[A](B(), B(), B())) // All B's
func(List(C(), C(), C()))    // All C's
func(List(C(), B(), C()))    // Something else

或者

import cats.implicits._

object AllB {
  def unapply(listOfA: List[A]): Option[List[B]] = 
    listOfA.traverse { case b@B() => Some(b); case _ => None }
}
object AllC {
  def unapply(listOfA: List[A]): Option[List[C]] = 
    listOfA.traverse { case c@C() => Some(c); case _ => None }
}

def func(theList: List[A]) = theList match {
  case AllB(listOfB) => println("All B's")
  case AllC(listOfC) => println("All C's")
  case _             => println("Something else")
}

func(List(B(), B(), B()))    // All B's
func(List[A](B(), B(), B())) // All B's
func(List(C(), C(), C()))    // All C's
func(List(C(), B(), C()))    // Something else

或者您可以定义一个类来创建所有必要的提取器并删除代码重复

class All[SubT: ClassTag] {
  def unapply[T >: SubT](listOfA: List[T]): Option[List[SubT]] = 
    listOfA.traverse { case x: SubT => Some(x); case _ => None }
}

object AllB extends All[B]
object AllC extends All[C]
// val AllB = new All[B]
// val AllC = new All[C]

def func(theList: List[A]) = theList match {
  case AllB(listOfB) => println("All B's")
  case AllC(listOfC) => println("All C's")
  case _             => println("Something else")
}

func(List(B(), B(), B()))    // All B's
func(List[A](B(), B(), B())) // All B's
func(List(C(), C(), C()))    // All C's
func(List(C(), B(), C()))    // Something else

我猜,最简单的就是使用 Shapeless

import shapeless.TypeCase

val AllB = TypeCase[List[B]]
val AllC = TypeCase[List[C]]

def func(theList: List[A]) = theList match {
  case AllB(listOfB) => println("All B's")
  case AllC(listOfC) => println("All C's")
  case _             => println("Something else")
}

func(List(B(), B(), B()))    // All B's
func(List[A](B(), B(), B())) // All B's
func(List(C(), C(), C()))    // All C's
func(List(C(), B(), C()))    // Something else

https://github.com/milessabin/shapeless/wiki/Feature-overview:-shapeless-2.0.0#type-safe-cast

在 Shapeless 类型Typeable中定义了类。只是它的列表实例的定义比@LuisMiguelMejíaSuárez的答案更棘手(即,使用运行时反射)

/** Typeable instance for `Traversable`.    
 *  Note that the contents be will tested for conformance to the element type. */  
implicit def genTraversableTypeable[CC[X] <: Iterable[X], T]
  (implicit mCC: ClassTag[CC[_]], castT: Typeable[T]): Typeable[CC[T] with Iterable[T]] =
  // Nb. the apparently redundant `with Iterable[T]` is a workaround for a
  // Scala 2.10.x bug which causes conflicts between this instance and `anyTypeable`.
  new Typeable[CC[T]] {
    def cast(t: Any): Option[CC[T]] =
      if(t == null) None
      else if(mCC.runtimeClass isInstance t) {
        val cc = t.asInstanceOf[CC[Any]]
        if(cc.forall(_.cast[T].isDefined)) Some(t.asInstanceOf[CC[T]])
        else None
      } else None
    def describe = s"${safeSimpleName(mCC)}[${castT.describe}]"
  }

https://github.com/milessabin/shapeless/blob/master/core/src/main/scala/shapeless/typeable.scala#L235-L250

另请参阅Scala 中模式匹配泛型类型的方法 https://gist.github.com/jkpl/5279ee05cca8cc1ec452fc26ace5b68b


推荐阅读