首页 > 解决方案 > 如何在 PySpark 中通过具有唯一值的列值标记连续重复项?

问题描述

我在 PySpark DataFrame 中有如下所示的数据:

| group | row | col |
+-------+-----+-----+
|   1   |  0  |  A  |
|   1   |  1  |  B  |
|   1   |  2  |  B  |
|   1   |  3  |  C  |
|   1   |  4  |  C  |
|   1   |  5  |  C  |
|   2   |  0  |  D  |
|   2   |  1  |  A  |
|   2   |  2  |  A  |
|   2   |  3  |  E  |
|   2   |  4  |  F  |
|   2   |  5  |  G  |
          ...

我想添加一个附加列,该列在按唯一值排序的情况下给出连续相同col值的每个“运行” (可以是字符串,int,并不重要)。grouprow

一个run可以轻松查看正在发生的事情的值选择是group、 start row、 endrow和重复col值的串联。对于上面的数据示例,这看起来像

| group | row | col |   run   |
+-------+-----+-----+---------+
|   0   |  0  |  A  | 0-0-0-A |
|   0   |  1  |  B  | 0-1-2-B |
|   0   |  2  |  B  | 0-1-2-B |
|   0   |  3  |  C  | 0-3-5-C |
|   0   |  4  |  C  | 0-3-5-C |
|   0   |  5  |  C  | 0-3-5-C |
|   1   |  0  |  D  | 1-0-0-D |
|   1   |  1  |  A  | 1-1-2-A |
|   1   |  2  |  A  | 1-1-2-A |
|   1   |  3  |  E  | 1-3-4-E |
|   1   |  4  |  E  | 1-3-4-E |
|   1   |  5  |  F  | 1-5-5-F |
          ...

我已经开始使用窗口函数来获得间隔的布尔分界:

win = Window.partitionBy('group').orderBy('row')
df = df.withColumn('next_col', f.lead('col').over(win))
df = df.withColumn('col_same', df['col'] == df['next_col'])

但似乎我必须使用调用f.lagcol_same获取实际的间隔(可能进入单独的列),然后调用另一个操作来run从这些额外的列中生成。我觉得可能有一种更简单、更有效的方法——任何建议都将不胜感激!

标签: pythonapache-sparkpysparkapache-spark-sql

解决方案


您可以使用lagandlead找到值col变化的边界:

df = spark_session.createDataFrame([
    Row(group=1, row=0, col='A'),
    Row(group=1, row=1, col='B'),
    Row(group=1, row=2, col='B'),
    Row(group=1, row=3, col='C'),
    Row(group=1, row=4, col='C'),
    Row(group=1, row=5, col='C'),
    Row(group=2, row=0, col='D'),
    Row(group=2, row=1, col='A'),
    Row(group=2, row=2, col='A'),
    Row(group=2, row=3, col='E'),
    Row(group=2, row=4, col='F'),
    Row(group=2, row=5, col='G'),
])

win = Window.partitionBy('group').orderBy('row')

df2 = df.withColumn('lag', lag('col').over(win)) \
    .withColumn('lead', lead('col').over(win)) \
    .withColumn('start', when(col('col') != coalesce(col('lag'), lit(-1)), col('row')))\
    .withColumn('end', when(col('col') != coalesce(col('lead'), lit(-1)), col('row')))\

df2.show()

输出:

+---+-----+---+----+----+-----+----+
|col|group|row| lag|lead|start| end|
+---+-----+---+----+----+-----+----+
|  A|    1|  0|null|   B|    0|   0|
|  B|    1|  1|   A|   B|    1|null|
|  B|    1|  2|   B|   C| null|   2|
|  C|    1|  3|   B|   C|    3|null|
|  C|    1|  4|   C|   C| null|null|
|  C|    1|  5|   C|null| null|   5|
|  D|    2|  0|null|   A|    0|   0|
|  A|    2|  1|   D|   A|    1|null|
|  A|    2|  2|   A|   E| null|   2|
|  E|    2|  3|   A|   F|    3|   3|
|  F|    2|  4|   E|   G|    4|   4|
|  G|    2|  5|   F|null|    5|   5|
+---+-----+---+----+----+-----+----+

要将信息放入问题中的单行,您可能需要再次洗牌:

win2 = Window.partitionBy('group', 'col')
df2.select(col('group'), col('col'), col('row'),
           concat_ws('-', col('group'), min('start').over(win2), max('end').over(win2), col('col')).alias('run'))\
    .orderBy('group', 'row')\
    .show()

输出:

+-----+---+---+-------+
|group|col|row|    run|
+-----+---+---+-------+
|    1|  A|  0|1-0-0-A|
|    1|  B|  1|1-1-2-B|
|    1|  B|  2|1-1-2-B|
|    1|  C|  3|1-3-5-C|
|    1|  C|  4|1-3-5-C|
|    1|  C|  5|1-3-5-C|
|    2|  D|  0|2-0-0-D|
|    2|  A|  1|2-1-2-A|
|    2|  A|  2|2-1-2-A|
|    2|  E|  3|2-3-3-E|
|    2|  F|  4|2-4-4-F|
|    2|  G|  5|2-5-5-G|
+-----+---+---+-------+

推荐阅读