首页 > 解决方案 > 用数据框中的位置替换数组中的元素 - Pyspark

问题描述

我有一个数据框:

|ID|CTA|
|------|
|11|1  |
|11|2  |
|11|7  |
|45|7  |

我需要按 ID 分组,并且每个 ID 的 ARRAY 长度为 7,但在 CTA 中有位置时指示 1

所以我的输出数据框应该如下所示:

|ID|CTAS             |
|------------------- |
|11|[1,1,0,0,0,0,1]  |
|45|[0,0,0,0,0,0,1]  |

你能帮助我吗?


更新

代码

如何将零留在数组中?

标签: pythondataframeapache-sparkpysparkapache-spark-sql

解决方案


您可以应用TRANSFORM表达式并迭代sequence(1, 7)以检查该值是否包含在 CTAS 列中:

import pyspark.sql.functions as f

group_df = (df
            .groupBy('ID')
            .agg(f.collect_list('CTA').alias('CTAS')))
# +---+---------+
# |ID |CTAS     |
# +---+---------+
# |11 |[1, 2, 7]|
# |45 |[7]      |
# +---+---------+

pos_df = (group_df
          .withColumn('CTAS', 
                      f.expr('transform(sequence(1, 7), value -> cast(array_contains(CTAS, value) as int))')))
pos_df.sort('ID').show(truncate=False)
# +---+---------------------+
# |ID |CTAS                 |
# +---+---------------------+
# |11 |[1, 1, 0, 0, 0, 0, 1]|
# |45 |[0, 0, 0, 0, 0, 0, 1]|
# +---+---------------------+

推荐阅读