Spark UDF not giving the rolling count correctly

Problem description

I have a Spark UDF that computes an exact rolling count over a column. If I need a 24-hour rolling count, then for an entry with time 2020-10-02 09:04:00 I have to look back to 2020-10-01 09:04:00 (very precisely).

The rolling-count UDF works fine and gives the correct counts when I run it locally, but when I run it on a cluster the results are incorrect. Here are sample input and output.

Input

+---------+-----------------------+
|OrderName|Time                   |
+---------+-----------------------+
|a        |2020-07-11 23:58:45.538|
|a        |2020-07-12 00:00:07.307|
|a        |2020-07-12 00:01:08.817|
|a        |2020-07-12 00:02:15.675|
|a        |2020-07-12 00:05:48.277|
+---------+-----------------------+

Expected output

+---------+-----------------------+-----+
|OrderName|Time                   |Count|
+---------+-----------------------+-----+
|a        |2020-07-11 23:58:45.538|1    |
|a        |2020-07-12 00:00:07.307|2    |
|a        |2020-07-12 00:01:08.817|3    |
|a        |2020-07-12 00:02:15.675|1    |
|a        |2020-07-12 00:05:48.277|1    |
+---------+-----------------------+-----+

The last two entries are 4 and 5 when run locally, but on the cluster they are incorrect. My best guess is that the data is being distributed across executors and the UDF is also being called in parallel on each executor. Since one of the arguments to the UDF is a column (the partition key, OrderName, in this example), how can I control/correct the cluster's behaviour if that is the case, so that the count for each partition is computed correctly? Any suggestions are welcome.
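For what it's worth (this is not from the original post, just an illustrative sketch): if the intent is to keep every row of a given OrderName on a single partition, in time order, before a per-partition UDF runs, the DataFrame can be repartitioned by the key explicitly. The accepted solution below avoids the custom UDF entirely by using a window function instead. The session setup and input here are hypothetical.

import org.apache.spark.sql.SparkSession

// Hypothetical session and input; column names follow the sample above
val spark = SparkSession.builder().appName("rolling-count").master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(
  ("a", "2020-07-11 23:58:45.538"),
  ("a", "2020-07-12 00:00:07.307")
).toDF("OrderName", "Time")

// Force all rows of one OrderName onto the same partition and sort them by Time,
// so any per-partition logic sees a complete, ordered group
val prepared = df
  .repartition($"OrderName")
  .sortWithinPartitions($"Time")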

Tags: scala, apache-spark, user-defined-functions

Solution


Based on your comment, you want to count, for each record, the total number of records within the past 24 hours.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.LongType

// Assumes a spark-shell / notebook session where `spark` is already in scope
import spark.implicits._

// Sample data (guessed from your question)
val df = Seq(
  ("a", "2020-07-10 23:58:45.438", "1"),
  ("a", "2020-07-11 23:58:45.538", "1"),
  ("a", "2020-07-11 23:58:45.638", "1")
).toDF("OrderName", "Time", "Count")

// Build a millisecond-precision epoch value for the Time column by concatenating
// its UNIX timestamp (seconds) with its fractional part, then casting to Long
val df2 = df.withColumn("unix_time", concat(unix_timestamp($"Time"), split($"Time", "\\.")(1)).cast(LongType))

val noOfMilisecondsDay: Long = 24 * 60 * 60 * 1000

// Create a window per `OrderName` and select rows from `current time - 24 hours` to `current time`
val winSpec = Window.partitionBy("OrderName").orderBy("unix_time").rangeBetween(Window.currentRow - noOfMilisecondsDay, Window.currentRow)

// Finally, perform your COUNT or SUM(Count) as needed
val finalDf = df2.withColumn("tot_count", count("OrderName").over(winSpec))

// or: val finalDf = df2.withColumn("tot_count", sum("Count").over(winSpec))
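
A small usage sketch (not part of the original answer; it assumes the snippet above has already been run in the same session): the unix_time helper column is only needed for the window range, so it can be dropped before inspecting the result.

// Drop the helper column and inspect the rolling counts in time order
finalDf
  .drop("unix_time")
  .orderBy("Time")
  .show(false)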
