首页 > 技术文章 > 强大的Scala模式匹配

barrywxx 2019-05-11 23:08 原文

用过Scala的模式匹配,感觉Java的弱爆了。Scala几乎可以匹配任何数据类型,如果默认的不能满足你的要求,你可以自定义模式匹配。

介绍Scala的模式匹配前,我们先了解清楚unapply()与unapplySeq()两个方法:

名字叫做unapply和unapplySeq的方法在Scala里也是有特殊含义的。

我们前面说过case class在做pattern match时非常好用,而除case class之外,有unapply或unapplySeq方法的对象在pattern match时也有非常好的应用场景。

比方这段代码:

1
2
3
object Square {
  def unapply(z: Double): Option[Double] = Some(math.sqrt(z))
}

我们定义了一个unapply方法,用来计算平方根。

我们能够像调用普通方法一样的调用它:

1
2
val number: Double = 36.0
Square.unapply(number)

这样会得到36的平方根:6。实际上返回值是Some(6)。

上面的方式是对unapply的浪费。unapply真正的优点是这种:

1
2
3
4
5
val number: Double = 36.0
number match {
  case Square(n) => println(s"square root of $number is $n")
  case _ => println("nothing matched")
}

这样我们无需显式调用unapply方法,而把是它用在pattern match中。让编译器替我们调用它。

当我们写下这段pattern match的代码时,编译器事实上替我们做了好几件事:

  1. 调用unapply,传入number
  2. 接收返回值并推断返回值是None,还是Some
  3. 假设是Some,则将其解开,并将当中的值赋值给n(就是case Square(n)中的n)

这段代码反编译出来是这个样子的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
  double number = 36.0D;
  double d1 = number;
  Option localOption = Square..MODULE$.unapply(d1);
  //调用unapply,传入number
  BoxedUnit localBoxedUnit;
  if (localOption.isEmpty()) {//推断返回值是None
    Predef..MODULE$.println("nothing matched");
    localBoxedUnit = BoxedUnit.UNIT;
  }
  else {//推断返回值是Some
    double n = BoxesRunTime.unboxToDouble(localOption.get());
    //将Some解开,并将当中的值赋值给n
    Predef..MODULE$.println(new StringContext(Predef..MODULE$.wrapRefArray((Object[]) new String[] {
      "square root of ", " is ", ""
    }) ).s(Predef..MODULE$.genericWrapArray(new Object[] {
      BoxesRunTime.boxToDouble(number), BoxesRunTime.boxToDouble(n)
    })));
    localBoxedUnit = BoxedUnit.UNIT;
  }

假设没有unapply方法和pattern match语法之间的这样的结合,我们自己写代码要写成什么样子呢?

也许会比上面反编译的代码简单一些,可是显式地调用开平方的方法。用if else来推断Option,以及将真正的返回值从Option里面解出来这三件事是免不掉的。

unapplySeq和unapply的作用非常是类似,比如这样:

1
2
3
4
5
6
object Names {
  def unapplySeq(str: String): Option[Seq[String]] = {
    if (str.contains(",")) Some(str.split(","))
    else None
  }
}

我们定义一个unapplySeq方法,用逗号作为分隔符来把字符串拆开。

然后我们能够这样应用它:

1
2
3
4
5
6
7
8
val namesString = "xiao ming,xiao hong,tom"
namesString match {
  case Names(first, second, third) => {
    println("the string contains three people's names")
    println(s"$first $second $third")
  }
  case _ => println("nothing matched")
}

与上面的样例非常是类似,只是编译器在这里替我们做的事情很多其它了:

  1. 调用unapplySeq,传入namesString
  2. 接收返回值并推断返回值是None,还是Some
  3. 假设是Some,则将其解开
  4. 推断解开之后得到的sequence中的元素的个数是否是三个
  5. 假设是三个,则把三个元素分别取出,赋值给first,second和third

假设没有unapplySeq方法和pattern match语法之间的这样的结合,我们自己写代码来做这五件事会显得非常是繁琐。

 

大家了解清楚unapply()与unapplySeq()两个方法,我就举个很实用的例子。

现实中往往我们有这样的一个需求:比如http请求返回体为json字符串,我们往往只需要获取json中的部分字段值。

这个需求,Java的做法是,必须知道所有json中所有字段及类型,并定义对应的JavaBean,将返回的json字符串先转换为对应的JavaBean,再通过javaBean获取想要的字段信息。

那么Scala使用模式匹配就很轻松搞定直接获取想要的字段值。如下:

先定义一个unapplySeq()实现json的模式匹配(另json字符串插值写法可参考我以前文章:Scala字符串插值 - StringContext):


package spray.json

import scala.collection.SortedMap
class JsonInterpolation(sc: StringContext) {
  object json {
    def apply(args: JsValue*): JsValue =
      new JsonParser(ParserInput(sc, args), true).parseJsValue()

    def unapplySeq(input: JsValue): Option[Seq[JsValue]] = {

      val placeHolders = Seq.range(0, sc.parts.length-1).map(x => JsNumber(Integer.MAX_VALUE - x) )

      val pi = ParserInput(sc, placeHolders)
      val pattern = new JsonParser(pi, true).parseJsValue()

      val results = collection.mutable.ArrayBuffer[JsValue]()
      Seq.range(0, sc.parts.length-1).foreach { x => results += null }

      try {
        patternMatch(pattern, input, placeHolders, results)
        Some(results.toSeq)
      }
      catch {
        case ex: Throwable => None
      }

    }

    // TODO report friendly
    private def patternMatch(pattern: JsValue, input: JsValue, placeHolders: Seq[JsValue], results: collection.mutable.ArrayBuffer[JsValue]): Unit = {

      def isPlaceHolder(value: JsNumber) = {
        val num = value.value.toInt
        val index = Integer.MAX_VALUE - num.toInt
        num > 0 && index < placeHolders.size && placeHolders(index).eq(value)
      }

      pattern match {
        case x: JsObject =>
          x.fields.foreach {
            case (key, n @ JsNumber(num)) if isPlaceHolder(n) =>
              val index = Integer.MAX_VALUE - num.toInt
              assert(input.asJsObject.fields contains key)
              results(index) = input.asJsObject.fields(key)

            case (key, value) =>
              assert(input.asJsObject.fields contains key)
              patternMatch(value, input.asJsObject.fields(key), placeHolders, results)
          }
        case x: JsArray =>
          assert(input.isInstanceOf[JsArray])
          assert(input.asInstanceOf[JsArray].elements.size >= x.elements.size)
          x.elements.zipWithIndex.foreach {
            case (x: JsNumber, y: Int) if isPlaceHolder(x) =>
              val index = Integer.MAX_VALUE - x.value.toInt
              results(index) = input.asInstanceOf[JsArray].elements(y)
            case (x: JsValue,y: Int)=>
              patternMatch(x, input.asInstanceOf[JsArray].elements.apply(y), placeHolders, results)
          }
        case x: JsString =>
          assert(x == input)
        case x: JsBoolean =>
          assert(x == input)
        case x: JsNumber =>
          assert(x == input)
        case x@ JsNull =>
          assert(x == input)
      }
    }
  }
}

 

定义好unpplySeq(),那么下面就看下如何通过模式匹配获取json字符传中我们关系的字段值:

import spray.json._                                             
                                                                
/**                                                             
  * 类功能描述:                                                      
  *                                                             
  * @author WangXueXing create at 19-5-11 下午10:37               
  * @version 1.0.0                                              
  */                                                            
object ObjectMatch {                                            
  def main(args: Array[String]): Unit = {                       
    val json = json"""{"id": 12333,"sku_no":"sku35352523"}"""   
    json match {                                                
      case json"""{"sku_no": $skuNo}""" =>  println(skuNo)      
      case _ => println(json)                                   
    }                                                           
                                                                
  }                                                             
}                                                               

输出:

"sku35352523"

如上,我们就可以选择性的拿到sku_no字段的值,是不是轻松搞定!

 

推荐阅读