Computing quartiles over a windowed DataFrame

Problem description

I have some data which, for the sake of discussion, is given by:

val schema = Seq("id", "day", "value")
val data = Seq(
    (1, 1, 1), 
    (1, 2, 11),
    (1, 3, 1), 
    (1, 4, 11),
    (1, 5, 1), 
    (1, 6, 11),
    (2, 1, 1), 
    (2, 2, 11),
    (2, 3, 1), 
    (2, 4, 11),
    (2, 5, 1), 
    (2, 6, 11) 
  )   

val df = sc.parallelize(data).toDF(schema: _*) 

I would like to compute quartiles for each ID over a moving window of days, something like:

val w = Window.partitionBy("id").orderBy("day").rangeBetween(-2, 0)
df.select(col("id"),col("day"),collect_list(col("value")).over(w),quartiles(col("value")).over(w).as("Quartiles"))

Of course, there is no quartiles function here, so I need to write a UserDefinedAggregateFunction. The following is a simple (though not scalable) solution (based on CollectionFunction):

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class QuartilesFunction extends UserDefinedAggregateFunction {
    def inputSchema: StructType =
        StructType(StructField("value", DoubleType, false) :: Nil)

    // the buffer keeps the running values as two sorted halves: everything in
    // `lower` is <= everything in `upper`, and `lower` holds floor(n/2) elements
    def bufferSchema: StructType = StructType(
        StructField("lower", ArrayType(DoubleType, true), true) ::
        StructField("upper", ArrayType(DoubleType, true), true) :: Nil)

    override def dataType: DataType = ArrayType(DoubleType, true)

    def deterministic: Boolean = true

    def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = IndexedSeq[Double]()
        buffer(1) = IndexedSeq[Double]()
    }

  // re-split the concatenation so that `lower` gets the first floor(n/2)
  // elements; assumes both halves are sorted and max(lower) <= min(upper)
  def rebalance(lower: IndexedSeq[Double], upper: IndexedSeq[Double]) = {
    (lower ++ upper).splitAt((lower.length + upper.length) / 2)
  }

  // median of an already-sorted sequence, e.g.
  // sorted_median(IndexedSeq(1.0, 2.0, 3.0, 4.0)) == Some(2.5)
  def sorted_median(x: IndexedSeq[Double]): Option[Double] = {
    if (x.length == 0) {
      None
    } else {
      val N = x.length
      val (lower, upper) = x.splitAt(N / 2)
      Some(
        if (N % 2 == 0) {
          (lower.last + upper.head) / 2.0
        } else {
          upper.head
        }
      )
    }
  }

  // this is how to update the buffer given an input: a new value goes into
  // whichever half it belongs to, then the halves are rebalanced
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val lower = buffer(0).asInstanceOf[IndexedSeq[Double]]
    val upper = buffer(1).asInstanceOf[IndexedSeq[Double]]
    val value = input.getAs[Double](0)
    // compare against upper.head: after rebalancing, `lower` can be empty
    // while `upper` is not (e.g. right after the very first value)
    if (upper.nonEmpty && value >= upper.head) {
      buffer(1) = (value +: upper).sortWith(_ < _)
    } else {
      buffer(0) = (lower :+ value).sortWith(_ < _)
    }
    val (result0, result1) = rebalance(
      buffer(0).asInstanceOf[IndexedSeq[Double]],
      buffer(1).asInstanceOf[IndexedSeq[Double]])
    buffer(0) = result0
    buffer(1) = result1
  }

  // this is how to merge two objects with the buffer schema type; the halves
  // of two independent buffers interleave, so re-sort before splitting again
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val merged = (buffer1(0).asInstanceOf[IndexedSeq[Double]] ++
                  buffer2(0).asInstanceOf[IndexedSeq[Double]] ++
                  buffer1(1).asInstanceOf[IndexedSeq[Double]] ++
                  buffer2(1).asInstanceOf[IndexedSeq[Double]]).sortWith(_ < _)
    val (result0, result1) = merged.splitAt(merged.length / 2)
    buffer1(0) = result0
    buffer1(1) = result1
  }

    def evaluate(buffer: Row): Array[Option[Double]] = {
        val lower =
            if (buffer(0) == null) IndexedSeq[Double]()
            else buffer(0).asInstanceOf[IndexedSeq[Double]]
        val upper =
            if (buffer(1) == null) IndexedSeq[Double]()
            else buffer(1).asInstanceOf[IndexedSeq[Double]]
        val Q1 = sorted_median(lower)
        // NB: for an even total count this takes the upper of the two middle
        // elements rather than averaging them
        val Q2 = if (upper.length == 0) None else Some(upper.head)
        val Q3 = sorted_median(upper)
        Array(Q1, Q2, Q3)
    }
}
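To sanity-check the buffer logic by hand (no Spark required), the day-3 window of id = 1 plays out like this:

// updates 1.0, 11.0, 1.0 leave lower = [1.0], upper = [1.0, 11.0],
// so evaluate returns Array(Some(1.0), Some(1.0), Some(6.0))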

However, executing the following produces an error:

val quartiles = new QuartilesFunction
df.select('*).show
val w = org.apache.spark.sql.expressions.Window.partitionBy("id").orderBy("day").rangeBetween(-2, 0)
val x = df.select(
  col("id"), col("day"),
  collect_list(col("value")).over(w),
  quartiles(col("value")).over(w).as("Quantiles"))
x.show

The error is:

org.apache.spark.SparkException: Task not serializable

The offending function seems to be sorted_median. If I replace the code with the same class, except that sorted_median is commented out entirely and evaluate is stubbed to return constants:

    def evaluate(buffer: Row): Array[Option[Double]] = {
        // ... same null handling for lower and upper as above ...
        val Q1 = Some(1.0) // sorted_median(lower)
        val Q2 = Some(2.0) // if (upper.length == 0) None else Some(upper.head)
        val Q3 = Some(3.0) // sorted_median(upper)
        Array(Q1, Q2, Q3)
    }

then everything works, except that it doesn't compute quartiles (obviously). I don't understand the error, and the rest of the stack trace is no more illuminating. Can someone help me understand what the problem is, and/or how to compute these quartiles?

Tags: scala, apache-spark, apache-spark-sql

Solution


If you have a Hive context (or Hive support enabled), you can use the percentile UDAF like this:

val dfQuartiles = df.select(
  col("id"),
  col("day"),
  collect_list(col("value")).over(w).as("values"),
  callUDF("percentile", col("value"), lit(0.25)).over(w).as("Q1"),
  callUDF("percentile", col("value"), lit(0.50)).over(w).as("Q2"),
  callUDF("percentile", col("value"), lit(0.75)).over(w).as("Q3"),
  callUDF("percentile", col("value"), lit(1.0)).over(w).as("Q4")
)
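Hive's percentile interpolates linearly between the two nearest values, so for the two-day window [1, 11] it gives Q1 = 3.5, Q2 = 6.0, Q3 = 8.5. On Spark 3.1+ (a version assumption; not part of the original answer) the same idea works without Hive via the built-in percentile_approx; a minimal sketch:

import org.apache.spark.sql.functions.percentile_approx

// assumes Spark 3.1+; 10000 is the accuracy knob (higher = closer to exact)
val dfQuartilesApprox = df.select(
  col("id"),
  col("day"),
  percentile_approx(col("value"), lit(0.25), lit(10000)).over(w).as("Q1"),
  percentile_approx(col("value"), lit(0.50), lit(10000)).over(w).as("Q2"),
  percentile_approx(col("value"), lit(0.75), lit(10000)).over(w).as("Q3")
)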

Alternatively, you can use a UDF to compute the quartiles from values (since you have that array anyway):

val calcPercentile = udf((xs: Seq[Int], percentile: Double) => {
  // lower nearest-rank: index floor((n - 1) * p) into the sorted values
  val sorted = xs.sorted
  val index = ((sorted.size - 1) * percentile).toInt
  sorted(index)
})

val dfQuartiles = df.select(
  col("id"),
  col("day"),
  collect_list(col("value")).over(w).as("values")
)
.withColumn("Q1",calcPercentile($"values",lit(0.25)))
.withColumn("Q2",calcPercentile($"values",lit(0.50)))
.withColumn("Q3",calcPercentile($"values",lit(0.75)))
.withColumn("Q4",calcPercentile($"values",lit(1.00)))
