首页 > 解决方案 > 在多列上使用 collect_list 和 collect_set 时如何保留列表的顺序?

问题描述

使用 collect_list 时如何保留列的顺序?我有一个日期列 (col1),当我在其上调用 collect_list 函数时,不会保留顺序。这是我的代码,带有示例输入和输出。

输入数据框:


df = sqlContext.createDataFrame([('1', 201001,3400,1600,65,320,400,), ('1', 201002,5200,1600,65,320,400,), ('1', 201003,65,1550,32,320,400,), ('2', 201505,3200,1800,12,1,40,), ('2', 201508,3200,3200,12,1,40,), ('3', 201412,40,40,12,1,3,)], 
                                  ['ColA', 'Col1','Col2','Col3','Col4','Col5','Col6',])

+----+------+----+----+----+----+----+
|ColA|  Col1|Col2|Col3|Col4|Col5|Col6|
+----+------+----+----+----+----+----+
|   1|201001|3400|1600|  65| 320| 400|
|   1|201002|5200|1600|  65| 320| 400|
|   1|201003|  65|1550|  32| 320| 400|
|   2|201505|3200|1800|  12|   1|  40|
|   2|201508|3200|3200|  12|   1|  40|
|   3|201412|  40|  40|  12|   1|   3|
+----+------+----+----+----+----+----+

预期输出:

df = sqlContext.createDataFrame([(1,['201001', '201002', '201003'],[3400, 5200, 65],[1600, 1600, 1550],[65,32],[320],[400],), (2,['201505', '201508'],[3200, 3200],[1800, 3200],[12],[1],[40],),
(3,['201412'],[40],[40],[12],[1],[3],)], ['ColA', 'Col1','Col2','Col3','Col4','Col5','Col6',])
df.show()

+----+--------------------+----------------+------------------+--------+-----+-----+
|ColA|                Col1|            Col2|              Col3|    Col4| Col5| Col6|
+----+--------------------+----------------+------------------+--------+-----+-----+
|   1|[201001, 201002, ...|[3400, 5200, 65]|[1600, 1600, 1550]|[65, 32]|[320]|[400]|
|   2|    [201505, 201508]|    [3200, 3200]|      [1800, 3200]|    [12]|  [1]| [40]|
|   3|            [201412]|            [40]|              [40]|    [12]|  [1]|  [3]|
+----+--------------------+----------------+------------------+--------+-----+-----+

这是有效但不存储 col1 顺序的代码:


def aggregation(df, groupby_column, cols_to_list, cols_to_set):
  exprs = [F.collect_list(F.col(c)).alias(c) for c in cols_to_list]\
          + [F.collect_set(F.col(c)).alias(c) for c in cols_to_set]
  return df.groupby(*groupby_column).agg(*exprs)

groupby_column = ['ColA']

cols_to_list = ['Col1', 'Col2', 'Col3',]
cols_to_set = ['Col4', 'Col5', 'Col6',]

df = aggregation(df, groupby_column, cols_to_list, cols_to_set)

标签: apache-sparkpysparkapache-spark-sql

解决方案


感谢@pault,我能够理解问题所在。另一个页面上发布的解决方案很复杂,特别是如果您有太多列要使用并计划同时使用 collect_list 和 collect_set 函数。我能够通过执行与重新分区配对的 orderBy 来解决它,这样我的所有数据都在一个分区上,而不是导致问题开始的多个分区上。请记住,重新分区是一项昂贵的操作,因此请谨慎使用。

这是任何人列表的代码:


def aggregation(df, groupby_column, cols_to_list, cols_to_set):
  df = df.orderBy(colA).repartition(1)
  exprs = [F.collect_list(F.col(c)).alias(c) for c in cols_to_list]\
          + [F.collect_set(F.col(c)).alias(c) for c in cols_to_set]
  return df.groupby(*groupby_column).agg(*exprs)

groupby_column = ['ColA']

cols_to_list = ['Col1', 'Col2', 'Col3',]
cols_to_set = ['Col4', 'Col5', 'Col6',]

df = aggregation(df, groupby_column, cols_to_list, cols_to_set)

推荐阅读