首页 > 解决方案 > spark sql-为sum函数提供列表参数

问题描述

我正在使用火花数据框 API。我试图给 sum() 一个列表参数,其中包含列名作为字符串。当我将列名直接放入函数时-脚本有效'当我尝试将其作为列表类型的参数提供给函数时-我收到错误:

"py4j.protocol.Py4JJavaError: An error occurred while calling o155.sum.
: java.lang.ClassCastException: java.util.ArrayList cannot be cast to java.lang.String"

对 groupBy() 使用相同类型的列表参数是有效的。这是我的脚本:

groupBy_cols = ['date_expense_int', 'customer_id']
agged_cols_list = ['total_customer_exp_last_m','total_customer_exp_last_3m']

df = df.groupBy(groupBy_cols).sum(agged_cols_list)

当我这样写它时,它可以工作:

df = df.groupBy(groupBy_cols).sum('total_customer_exp_last_m','total_customer_exp_last_3m')

我还尝试通过使用给 sum() 列列表

agged_cols_list2 = []
for i in agged_cols_list:
    agged_cols_list2.append(col(i))

也没有工作

标签: apache-sparkpysparkapache-spark-sql

解决方案


如果你有一个像下面这样的 df 并且想要总结一个字段列表

df.show(5,truncate=False)

+---+---------+----+
|id |subject  |mark|
+---+---------+----+
|100|English  |45  |
|100|Maths    |63  |
|100|Physics  |40  |
|100|Chemistry|94  |
|100|Biology  |74  |
+---+---------+----+

only showing top 5 rows

agged_cols_list=['subject', 'mark']

df.groupBy("id").agg(*[sum(col(c)) for c in agged_cols_list]).show(5,truncate=False)

+---+------------+---------+
|id |sum(subject)|sum(mark)|
+---+------------+---------+
|125|null        |330.0    |
|124|null        |332.0    |
|155|null        |304.0    |
|132|null        |382.0    |
|154|null        |300.0    |
+---+------------+---------+

请注意, sum(subject) 为 null,因为它是一个字符串列。在这种情况下,您可能希望将计数应用于主题并将总和应用于标记。所以你可以使用字典

summary={ "subject":"count","mark":"sum" }

df.groupBy("id").agg(summary).show(5,truncate=False)

+---+--------------+---------+
|id |count(subject)|sum(mark)|
+---+--------------+---------+
|125|5             |330.0    |
|124|5             |332.0    |
|155|5             |304.0    |
|132|5             |382.0    |
|154|5             |300.0    |
+---+--------------+---------+
only showing top 5 rows

推荐阅读