How to select a column based on the value of a GroupBy column without knowing any specific value in Spark

Problem description

I have some data like this:

val simpleData = Seq(
 ("James",   "Sales",     3000,  1, -1, -1),
 ("Michael", "Sales",     4600,  2, -1, -1),
 ("Robert",  "Sales",     4100,  3, -1, -1),
 ("Maria",   "Finance",   3000, -1,  1, -1),
 ("James",   "Sales",     3000,  4, -1, -1),
 ("Scott",   "Finance",   3300, -1,  2, -1),
 ("Jen",     "Finance",   3900, -1,  3, -1),
 ("Jeff",    "Marketing", 3000, -1, -1,  1),
 ("Kumar",   "Marketing", 2000, -1, -1,  2),
 ("Saif",    "Sales",     4100,  5, -1, -1)
)

The DataFrame is defined as follows:

import spark.implicits._  // assumes an active SparkSession named `spark`; needed for toDF and $-column syntax

val df = simpleData.toDF("employee_name", "department", "salary", "sales_no", "finance_no", "marketing_no")

I want to list all employee numbers for each department, like this:

+----------+------------------------+
|department|collect_list(finance_no)|
+----------+------------------------+
|Sales     |[1, 2, 3, 4, 5]         |
|Finance   |[1, 2, 3]               |
|Marketing |[1, 2]                  |
+----------+------------------------+

I defined a UDF like this, trying to return the right employee-number column based on the department value:

import org.apache.spark.sql.functions.{collect_list, udf}

val departmentToNo = (dept: String) => {
  dept match {
    case "Sales"     => "sales_no"
    case "Finance"   => "finance_no"
    case "Marketing" => "marketing_no"
  }
}

df.groupBy($"department")
  .agg(collect_list(udf(departmentToNo).apply($"department")).as("newColumn"))

But what I got was the following:

+----------+--------------------------------------------------+
|department|newColumn                                         |
+----------+--------------------------------------------------+
|Sales     |[sales_no, sales_no, sales_no, sales_no, sales_no]|
|Finance   |[finance_no, finance_no, finance_no]              |
|Marketing |[marketing_no, marketing_no]                      |
+----------+--------------------------------------------------+

How can I achieve this in Spark SQL?

Tags: apache-spark, apache-spark-sql

Solution


A UDF operates on row values and can only return a value, so returning "sales_no" gives you the literal string, not the sales_no column itself. Instead, notice that in each row exactly one of the three *_no columns is positive: you can fold the columns together with when/otherwise to keep that positive value, then collect the results per department. Try the following snippet:

import org.apache.spark.sql.functions.{col, collect_list, when}

// Fold the candidate columns, keeping the first positive value in each row.
val columnsToCompare = Seq("sales_no", "finance_no", "marketing_no").map(col)
df.withColumn("summary", columnsToCompare.reduce((c1, c2) => when(c1 > 0, c1).otherwise(c2)))
  .groupBy("department").agg(collect_list("summary") as "all_cols_unique").show(false)
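Run against the sample data above, this should produce the lists the question asks for (note that collect_list makes no ordering guarantee, so the numbers may come back in a different order):

+----------+---------------+
|department|all_cols_unique|
+----------+---------------+
|Sales     |[1, 2, 3, 4, 5]|
|Finance   |[1, 2, 3]      |
|Marketing |[1, 2]         |
+----------+---------------+

If you would rather dispatch on the department value directly, a chained when that maps each department to its own column does the same job without scanning for the positive entry. This is a minimal sketch of that variant (the deptNo name and the hard-coded department list are illustrative, not from the original answer):

import org.apache.spark.sql.functions.{col, collect_list, when}

// Map each department to its own *_no column; rows whose department is
// not listed evaluate to null, and collect_list silently drops nulls.
val deptNo = when(col("department") === "Sales", col("sales_no"))
  .when(col("department") === "Finance", col("finance_no"))
  .when(col("department") === "Marketing", col("marketing_no"))

df.groupBy("department")
  .agg(collect_list(deptNo).as("all_nos"))
  .show(false)

Both variants stay inside Spark's native column expressions, so unlike the UDF attempt they never materialize column names as strings and remain visible to the Catalyst optimizer.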
