首页 > 解决方案 > PySpark Dataframe 转置为列表

问题描述

我正在使用 pyspark sql api,并尝试将具有重复值的行分组到其余内容列表中。它类似于转置,但不是旋转所有值,而是将值放入数组中。

电流输出:

group_id | member_id | name
55       | 123       | jake
55       | 234       | tim 
65       | 345       | chris

期望的输出:

group_id | members
55       | [[123, 'jake'], [234, 'tim']]
65       | [345, 'chris']

标签: pysparkpyspark-sql

解决方案


您需要groupbygroup_idandpyspark.sql.functions.collect_list()用作聚合函数。

至于组合member_idname列,您有两种选择:

选项 1:使用pyspark.sql.functions.array

from pyspark.sql.functions import array, collect_list

df1 = df.groupBy("group_id")\
    .agg(collect_list(array("member_id", "name")).alias("members"))

df1.show(truncate=False)
#+--------+-------------------------------------------------+
#|group_id|members                                          |
#+--------+-------------------------------------------------+
#|55      |[WrappedArray(123, jake), WrappedArray(234, tim)]|
#|65      |[WrappedArray(345, chris)]                       |
#+--------+-------------------------------------------------+

这将返回一个WrappedArray字符串数组。整数被转换为字符串,因为你不能有混合类型的数组。

df1.printSchema()
#root
# |-- group_id: integer (nullable = true)
# |-- members: array (nullable = true)
# |    |-- element: array (containsNull = true)
# |    |    |-- element: string (containsNull = true)

选项 2:使用pyspark.sql.functions.struct

from pyspark.sql.functions import collect_list, struct 

df2 = df.groupBy("group_id")\
    .agg(collect_list(struct("member_id", "name")).alias("members"))

df2.show(truncate=False)
#+--------+-----------------------+
#|group_id|members                |
#+--------+-----------------------+
#|65      |[[345,chris]]          |
#|55      |[[123,jake], [234,tim]]|
#+--------+-----------------------+

这将返回一个结构数组,其中包含用于member_id和的命名字段name

df2.printSchema()
#root
# |-- group_id: integer (nullable = true)
# |-- members: array (nullable = true)
# |    |-- element: struct (containsNull = true)
# |    |    |-- member_id: integer (nullable = true)
# |    |    |-- name: string (nullable = true)

struct 方法的有用之处在于您可以使用点访问器按名称访问嵌套数组的元素:

df2.select("group_id", "members.member_id").show()
#+--------+----------+
#|group_id| member_id|
#+--------+----------+
#|      65|     [345]|
#|      55|[123, 234]|
#+--------+----------+

推荐阅读