首页 > 解决方案 > 使用 Window() 计算 PySpark 中数组的滚动总和?

问题描述

我想计算给定 unix 时间戳的 ArrayType 列的滚动总和,并将其分组为 2 秒增量。示例输入/输出如下。我认为 Window() 函数会起作用,我对 PySpark 还很陌生,完全迷路了。非常感谢任何输入!

输入:

timestamp     vars 
2             [1,2,1,2]
2             [1,2,1,2]
3             [1,1,1,2]
4             [1,3,4,2]
5             [1,1,1,3]
6             [1,2,3,5]
9             [1,2,3,5]

预期输出:

+---------+-----------------------+
|timestamp|vars                   |
+---------+-----------------------+
|2        |[2.0, 4.0, 2.0, 4.0]   |
|4        |[4.0, 8.0, 7.0, 8.0]   |
|6        |[6.0, 11.0, 11.0, 16.0]|
|10       |[7.0, 13.0, 14.0, 21.0]|
+---------+-----------------------+

谢谢!

编辑:多列可以具有相同的时间戳/它们可能不连续。vars 的长度也可能 > 3。请寻找一个稍微通用的解决方案。

标签: apache-sparkpysparkapache-spark-sqlpyspark-dataframes

解决方案


对于 Spark 2.4+,您可以使用数组函数和高阶函数。此解决方案适用于不同的数组大小(如果每行之间的事件不同)。以下是解释的步骤:

vars首先,按 2 秒分组并在数组列中收集:

df = df.groupBy((ceil(col("timestamp") / 2) * 2).alias("timestamp")) \
       .agg(collect_list(col("vars")).alias("vars"))

df.show()

#+---------+----------------------+
#|timestamp|vars                  |
#+---------+----------------------+
#|6        |[[1, 1, 1], [1, 2, 3]]|
#|2        |[[1, 1, 1], [1, 2, 1]]|
#|4        |[[1, 1, 1], [1, 3, 4]]|
#+---------+----------------------+

vars在这里,我们将每个连续的 2 秒分组,并将数组收集到一个新列表中。现在,使用 Window 规范,您可以收集累积值并用于flatten展平子数组:

w = Window.orderBy("timestamp").rowsBetween(Window.unboundedPreceding, Window.currentRow)
df = df.withColumn("vars", flatten(collect_list(col("vars")).over(w)))
df.show()

#+---------+------------------------------------------------------------------+
#|timestamp|vars                                                              |
#+---------+------------------------------------------------------------------+
#|2        |[[1, 1, 1], [1, 2, 1]]                                            |
#|4        |[[1, 1, 1], [1, 2, 1], [1, 1, 1], [1, 3, 4]]                      |
#|6        |[[1, 1, 1], [1, 2, 1], [1, 1, 1], [1, 3, 4], [1, 1, 1], [1, 2, 3]]|
#+---------+------------------------------------------------------------------+

最后,使用aggregate函数zip_with对数组求和:

t = "aggregate(vars, cast(array() as array<double>), (acc, a) -> zip_with(acc, a, (x, y) -> coalesce(x, 0) + coalesce(y, 0)))"

df.withColumn("vars", expr(t)).show(truncate=False)

#+---------+-----------------+
#|timestamp|vars             |
#+---------+-----------------+
#|2        |[2.0, 3.0, 2.0]  |
#|4        |[4.0, 7.0, 7.0]  |
#|6        |[6.0, 10.0, 11.0]|
#+---------+-----------------+

放在一起:

from pyspark.sql.functions import ceil, col, collect_list, flatten, expr
from pyspark.sql import Window

w = Window.orderBy("timestamp").rowsBetween(Window.unboundedPreceding, Window.currentRow)
t = "aggregate(vars, cast(array() as array<double>), (acc, a) -> zip_with(acc, a, (x, y) -> coalesce(x, 0) + coalesce(y, 0)))"

nb_seconds = 2

df.groupBy((ceil(col("timestamp") / nb_seconds) * nb_seconds).alias("timestamp")) \
  .agg(collect_list(col("vars")).alias("vars")) \
  .withColumn("vars", flatten(collect_list(col("vars")).over(w))) \
  .withColumn("vars", expr(t)).show(truncate=False)

推荐阅读