首页 > 解决方案 > 数组列中所有元素的总和

问题描述

我是 spark 新手,并且有一个用例来查找列中所有值的总和。每列都是一个整数数组。

df.show(2,false)

+------------------+
|value             |
+------------------+
|[3,4,5]           |
+------------------+
|[1,2]             |
+------------------+ 

要找到的值 3 + 4 + 5 + 1 + 2 = 15

有人可以帮助/指导我如何实现这一目标吗?

编辑:我必须在 spark 2.3 中运行此代码

标签: dataframeapache-sparkapache-spark-sql

解决方案


一种选择是array对每一行求和,然后计算总和。这可以通过aggregateSpark 2.4.0 版提供的 Spark SQL 函数来完成。

val tmp = df.withColumn("summed_val",expr("aggregate(val,0,(acc, x) -> acc + x)"))

tmp.show()
+---+---------+----------+
| id|      val|summed_val|
+---+---------+----------+
|  1|[3, 4, 5]|        12|
|  2|   [1, 2]|         3|
+---+---------+----------+

//one row dataframe with the overall sum. collecting to a scalar value is possible too.
tmp.agg(sum("summed_val").alias("total")).show() 
+-----+
|total|
+-----+
|   15|
+-----+

另一种选择是使用explode. 但请注意,这种方法会产生大量需要汇总的数据。

import org.apache.spark.sql.functions.explode
val tmp = df.withColumn("elem",explode($"val"))
tmp.agg(sum($"elem").alias("total")).show()

推荐阅读