scala - Computing quartiles over a windowed DataFrame
Problem description
I have some data which, for the sake of discussion, is given by:
val schema = Seq("id", "day", "value")
val data = Seq(
  (1, 1, 1),
  (1, 2, 11),
  (1, 3, 1),
  (1, 4, 11),
  (1, 5, 1),
  (1, 6, 11),
  (2, 1, 1),
  (2, 2, 11),
  (2, 3, 1),
  (2, 4, 11),
  (2, 5, 1),
  (2, 6, 11)
)
val df = sc.parallelize(data).toDF(schema: _*)
I would like to compute quartiles for each id over a moving window of days, something like:
val w = Window.partitionBy("id").orderBy("day").rangeBetween(-2, 0)
df.select(col("id"), col("day"), collect_list(col("value")).over(w), quartiles(col("value")).over(w).as("Quartiles"))
Of course, there is no quartiles function, so I need to write a UserDefinedAggregateFunction. Below is a simple (though not scalable) solution (based on this CollectionFunction):
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
class QuartilesFunction extends UserDefinedAggregateFunction {
  def inputSchema: StructType =
    StructType(StructField("value", DoubleType, false) :: Nil)

  def bufferSchema: StructType = StructType(
    StructField("lower", ArrayType(DoubleType, true), true) ::
    StructField("upper", ArrayType(DoubleType, true), true) :: Nil)

  override def dataType: DataType = ArrayType(DoubleType, true)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = IndexedSeq[Double]()
    buffer(1) = IndexedSeq[Double]()
  }

  // keep the lower and upper halves of the running sorted list equal in size
  def rebalance(lower: IndexedSeq[Double], upper: IndexedSeq[Double]) =
    (lower ++ upper).splitAt((lower.length + upper.length) / 2)

  def sorted_median(x: IndexedSeq[Double]): Option[Double] = {
    if (x.length == 0) {
      return None // explicit return needed: a bare None expression is discarded and lower.last would throw on empty input
    }
    val N = x.length
    val (lower, upper) = x.splitAt(N / 2)
    Some(
      if (N % 2 == 0) {
        (lower.last + upper.head) / 2.0
      } else {
        upper.head
      }
    )
  }

  // this is how to update the buffer given an input
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val lower = buffer(0).asInstanceOf[IndexedSeq[Double]]
    val upper = buffer(1).asInstanceOf[IndexedSeq[Double]]
    val value = input.getAs[Double](0)
    if (lower.length == 0) {
      buffer(0) = Array(value)
    } else {
      if (value >= lower.last) {
        buffer(1) = (value +: upper).sortWith(_ < _)
      } else {
        buffer(0) = (lower :+ value).sortWith(_ < _)
      }
    }
    val (result0, result1) = rebalance(
      buffer(0).asInstanceOf[IndexedSeq[Double]],
      buffer(1).asInstanceOf[IndexedSeq[Double]])
    buffer(0) = result0
    buffer(1) = result1
  }

  // this is how to merge two objects with the buffer schema type
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1(0).asInstanceOf[IndexedSeq[Double]] ++ buffer2(0).asInstanceOf[IndexedSeq[Double]]
    buffer1(1) = buffer1(1).asInstanceOf[IndexedSeq[Double]] ++ buffer2(1).asInstanceOf[IndexedSeq[Double]]
    val (result0, result1) = rebalance(
      buffer1(0).asInstanceOf[IndexedSeq[Double]],
      buffer1(1).asInstanceOf[IndexedSeq[Double]])
    buffer1(0) = result0
    buffer1(1) = result1
  }

  def evaluate(buffer: Row): Array[Option[Double]] = {
    val lower =
      if (buffer(0) == null) IndexedSeq[Double]()
      else buffer(0).asInstanceOf[IndexedSeq[Double]]
    val upper =
      if (buffer(1) == null) IndexedSeq[Double]()
      else buffer(1).asInstanceOf[IndexedSeq[Double]]
    val Q1 = sorted_median(lower)
    val Q2 = if (upper.length == 0) None else Some(upper.head)
    val Q3 = sorted_median(upper)
    Array(Q1, Q2, Q3)
  }
}
However, executing the following produces an error:
val quartiles = new QuartilesFunction
df.select('*).show
val w = org.apache.spark.sql.expressions.Window.partitionBy("id").orderBy("day").rangeBetween(-2, 0)
val x = df.select(col("id"),col("day"),collect_list(col("value")).over(w),quartiles(col("value")).over(w).as("Quantiles"))
x.show
The error is:
org.apache.spark.SparkException: Task not serializable
The offending function appears to be sorted_median. If I replace the code with:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
class QuartilesFunction extends UserDefinedAggregateFunction {
  def inputSchema: StructType =
    StructType(StructField("value", DoubleType, false) :: Nil)

  def bufferSchema: StructType = StructType(
    StructField("lower", ArrayType(DoubleType, true), true) ::
    StructField("upper", ArrayType(DoubleType, true), true) :: Nil)

  override def dataType: DataType = ArrayType(DoubleType, true)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = IndexedSeq[Double]()
    buffer(1) = IndexedSeq[Double]()
  }

  def rebalance(lower: IndexedSeq[Double], upper: IndexedSeq[Double]) =
    (lower ++ upper).splitAt((lower.length + upper.length) / 2)

  /*
  def sorted_median(x: IndexedSeq[Double]): Option[Double] = {
    if (x.length == 0) {
      None
    }
    val N = x.length
    val (lower, upper) = x.splitAt(N / 2)
    Some(
      if (N % 2 == 0) {
        (lower.last + upper.head) / 2.0
      } else {
        upper.head
      }
    )
  }
  */

  // this is how to update the buffer given an input
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val lower = buffer(0).asInstanceOf[IndexedSeq[Double]]
    val upper = buffer(1).asInstanceOf[IndexedSeq[Double]]
    val value = input.getAs[Double](0)
    if (lower.length == 0) {
      buffer(0) = Array(value)
    } else {
      if (value >= lower.last) {
        buffer(1) = (value +: upper).sortWith(_ < _)
      } else {
        buffer(0) = (lower :+ value).sortWith(_ < _)
      }
    }
    val (result0, result1) = rebalance(
      buffer(0).asInstanceOf[IndexedSeq[Double]],
      buffer(1).asInstanceOf[IndexedSeq[Double]])
    buffer(0) = result0
    buffer(1) = result1
  }

  // this is how to merge two objects with the buffer schema type
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1(0).asInstanceOf[IndexedSeq[Double]] ++ buffer2(0).asInstanceOf[IndexedSeq[Double]]
    buffer1(1) = buffer1(1).asInstanceOf[IndexedSeq[Double]] ++ buffer2(1).asInstanceOf[IndexedSeq[Double]]
    val (result0, result1) = rebalance(
      buffer1(0).asInstanceOf[IndexedSeq[Double]],
      buffer1(1).asInstanceOf[IndexedSeq[Double]])
    buffer1(0) = result0
    buffer1(1) = result1
  }

  def evaluate(buffer: Row): Array[Option[Double]] = {
    val lower =
      if (buffer(0) == null) IndexedSeq[Double]()
      else buffer(0).asInstanceOf[IndexedSeq[Double]]
    val upper =
      if (buffer(1) == null) IndexedSeq[Double]()
      else buffer(1).asInstanceOf[IndexedSeq[Double]]
    val Q1 = Some(1.0) // sorted_median(lower)
    val Q2 = Some(2.0) // if (upper.length == 0) None else Some(upper.head)
    val Q3 = Some(3.0) // sorted_median(upper)
    Array(Q1, Q2, Q3)
  }
}
then everything runs, except that it obviously does not compute quartiles. I don't understand the error, and the rest of the stack trace is no more illuminating. Can someone help me understand what the problem is and/or how to compute these quartiles?
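As an aside, the median/rebalance logic itself can be sanity-checked in plain Scala outside Spark. A minimal sketch (my own renaming of the question's helpers; note the explicit return for the empty case, which a bare None expression does not provide):

```scala
// Median of an already-sorted sequence; returns None on empty input.
def sortedMedian(x: IndexedSeq[Double]): Option[Double] = {
  if (x.isEmpty) return None // explicit return: a bare None here would be discarded
  val (lower, upper) = x.splitAt(x.length / 2)
  Some(if (x.length % 2 == 0) (lower.last + upper.head) / 2.0 else upper.head)
}

// Keep the two halves of the running sorted list equal in size.
def rebalance(lower: IndexedSeq[Double], upper: IndexedSeq[Double]) =
  (lower ++ upper).splitAt((lower.length + upper.length) / 2)

val xs = IndexedSeq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
val (lo, hi) = rebalance(xs.take(2), xs.drop(2))
println(sortedMedian(lo)) // Q1 over the lower half: Some(2.0)
println(sortedMedian(hi)) // Q3 over the upper half: Some(5.0)
```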
Solution
If you have a Hive context (or hiveSupportEnabled), you can use the percentile UDAF as follows:
val dfQuartiles = df.select(
  col("id"),
  col("day"),
  collect_list(col("value")).over(w).as("values"),
  callUDF("percentile", col("value"), lit(0.25)).over(w).as("Q1"),
  callUDF("percentile", col("value"), lit(0.50)).over(w).as("Q2"),
  callUDF("percentile", col("value"), lit(0.75)).over(w).as("Q3"),
  callUDF("percentile", col("value"), lit(1.0)).over(w).as("Q4")
)
Alternatively, you can compute the quartiles with a UDF over values (since you have that array anyway):
val calcPercentile = udf((xs: Seq[Int], percentile: Double) => {
  val ss = xs.toSeq.sorted
  val index = ((ss.size - 1) * percentile).toInt
  ss(index)
})
val dfQuartiles = df.select(
    col("id"),
    col("day"),
    collect_list(col("value")).over(w).as("values")
  )
  .withColumn("Q1", calcPercentile($"values", lit(0.25)))
  .withColumn("Q2", calcPercentile($"values", lit(0.50)))
  .withColumn("Q3", calcPercentile($"values", lit(0.75)))
  .withColumn("Q4", calcPercentile($"values", lit(1.00)))
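Note that this UDF truncates the rank index rather than interpolating, so on even-sized windows it can differ from Hive's percentile. The core logic in plain Scala, for checking results (a sketch, not part of the Spark job):

```scala
// Plain-Scala version of the UDF body: nearest-lower-rank percentile,
// truncating (size-1)*p to an integer index (no interpolation).
def calcPercentileLocal(xs: Seq[Int], percentile: Double): Int = {
  val ss = xs.sorted
  val index = ((ss.size - 1) * percentile).toInt // truncation, not rounding
  ss(index)
}

println(calcPercentileLocal(Seq(1, 11, 1), 0.5))      // Q2 over a 3-value window: 1
println(calcPercentileLocal(Seq(1, 11, 1, 11), 0.75)) // (4-1)*0.75 = 2.25 truncates to index 2: 11
```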