首页 > 解决方案 > 使用 Scala 在 Spark 中将数组转换为自定义字符串格式

问题描述

我创建了一个DataFrame如下:

import spark.implicits._
import org.apache.spark.sql.functions._
val df = Seq(
  (1, List(1,2,3)),
  (1, List(5,7,9)),
  (2, List(4,5,6)),
  (2, List(7,8,9)),
  (2, List(10,11,12)) 
).toDF("id", "list")

val df1 = df.groupBy("id").agg(collect_set($"list").as("col1"))
df1.show(false)

然后我尝试将WrappedArray行值转换为字符串,如下所示:

import org.apache.spark.sql.functions._
def arrayToString = udf((arr: collection.mutable.WrappedArray[collection.mutable.WrappedArray[String]]) => arr.flatten.mkString(", "))

val d = df1.withColumn("col1", arrayToString($"col1"))
d: org.apache.spark.sql.DataFrame = [id: int, col1: string]

scala> d.show(false)
+---+----------------------------+
|id |col1                        |
+---+----------------------------+
|1  |1, 2, 3, 5, 7, 9            |
|2  |4, 5, 6, 7, 8, 9, 10, 11, 12|
+---+----------------------------+

我真正想要的是生成如下输出:

+---+----------------------------+
|id |col1                        |
+---+----------------------------+
|1  |1$2$3, 5$7$ 9               |
|2  |4$5$6, 7$8$9, 10$11$12      |
+---+----------------------------+

我怎样才能做到这一点?

标签: scalaapache-sparkapache-spark-sqluser-defined-functions

解决方案


你不需要一个udf函数,一个简单的concat_ws应该为你做的伎俩

import org.apache.spark.sql.functions._
val df1 = df.withColumn("list", concat_ws("$", col("list")))
            .groupBy("id")
            .agg(concat_ws(", ", collect_set($"list")).as("col1"))

df1.show(false)

这应该给你

+---+----------------------+
|id |col1                  |
+---+----------------------+
|1  |1$2$3, 5$7$9          |
|2  |7$8$9, 4$5$6, 10$11$12|
+---+----------------------+

像往常一样,如果内置函数可用udf,则应避免使用函数,因为函数需要将列数据序列化和反序列化为基本类型以进行计算,并分别从基元到列udf

更简洁,你可以避免这withColumn一步

val df1 = df.groupBy("id")
            .agg(concat_ws(", ", collect_set(concat_ws("$", col("list")))).as("col1"))

我希望答案有帮助


推荐阅读