How to compare a row's value against all other rows in the same group in PySpark

Problem statement

Consider the following data (see the generation code at the bottom):

+-----+-----+-------+--------+
|index|group|low_num|high_num|
+-----+-----+-------+--------+
|    0|    1|      1|       1|
|    1|    1|      2|       2|
|    2|    1|      3|       3|
|    3|    2|      1|       3|
+-----+-----+-------+--------+

Then, for a given index, I want to count how many rows in that index's group have a low_num strictly less than the high_num at that index.

For example, consider the second row, with index 1, group 1, and high_num 2. The high_num at index 1 is greater than the low_num at index 0, equal to the low_num at index 1, and less than the low_num at index 2. So the high_num at index 1 exceeds a low_num in the group exactly once, and I want the value in the answer column to be 1.

The dataset with the desired output:

+-----+-----+-------+--------+-------+
|index|group|low_num|high_num|desired|
+-----+-----+-------+--------+-------+
|    0|    1|      1|       1|      0|
|    1|    1|      2|       2|      1|
|    2|    1|      3|       3|      2|
|    3|    2|      1|       3|      1|
+-----+-----+-------+--------+-------+
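
To pin the rule down, here is a brute-force check of the definition in plain Python (a sketch for illustration only, not part of the question):

## (index, group, low_num, high_num) tuples mirroring the table above
rows = [
    (0, 1, 1, 1),
    (1, 1, 2, 2),
    (2, 1, 3, 3),
    (3, 2, 1, 3),
]
for idx, grp, _low, high in rows:
    ## count rows in the same group whose low_num is strictly below
    ## this row's high_num (the row itself included, strict comparison)
    desired = sum(1 for _, g, low, _ in rows if g == grp and low < high)
    print(idx, desired)  ## prints: 0 0, 1 1, 2 2, 3 1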

Dataset generation code

from pyspark.sql import SparkSession
spark = (
    SparkSession
    .builder
    .getOrCreate()
)
## Example df
## Note the inclusion of "desired" which is the desired output.
df = spark.createDataFrame(
    [
        (0, 1, 1, 1, 0),
        (1, 1, 2, 2, 1),
        (2, 1, 3, 3, 2),
        (3, 2, 1, 3, 1)
    ],
    schema=["index", "group", "low_num", "high_num", "desired"]
)

Pseudocode that might solve the problem

The pseudocode might look something like this:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

w_spec = Window.partitionBy("group").rowsBetween(
    Window.unboundedPreceding, Window.unboundedFollowing)

## F.collect_list_when does not exist
## F.current_col does not exist
## Probably wouldn't work like this anyway
ddf = df.withColumn("Counts", 
                    F.size(F.collect_list_when(
                             F.current_col("high_num") > F.col("low_num"), 1
                          ).otherwise(None).over(w_spec))
                   )
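
For reference, the intent behind this pseudocode can be expressed with existing APIs via a self-join within each group. The following is a rough sketch, assuming the df defined above; it is correct but less efficient than the answer below:

import pyspark.sql.functions as F

## Pair every row with every row of its own group, keep the pairs where
## this row's high_num exceeds the other row's low_num, then count the
## matches per index.
counts = (
    df.alias("a")
    .join(df.alias("b"), on="group")
    .where(F.col("a.high_num") > F.col("b.low_num"))
    .groupBy("a.index")
    .count()
)

## Rows with no match at all (like index 0) are absent from `counts`,
## so left-join back and default the count to 0.
ddf = (
    df.join(counts, on="index", how="left")
    .withColumn("Counts", F.coalesce(F.col("count"), F.lit(0)))
    .drop("count")
)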

Tags: pyspark, apache-spark-sql, aggregate

Solution


You can do a filter on the collect_list and check its size:

import pyspark.sql.functions as F

df2 = df.withColumn(
    'desired',
    ## collect_list gathers every low_num in the row's group (no ORDER BY,
    ## so the window frame is the whole partition); filter keeps the values
    ## below this row's high_num, and size counts them.
    F.expr('size(filter(collect_list(low_num) over (partition by group), x -> x < high_num))')
)

df2.show()
+-----+-----+-------+--------+-------+
|index|group|low_num|high_num|desired|
+-----+-----+-------+--------+-------+
|    0|    1|      1|       1|      0|
|    1|    1|      2|       2|      1|
|    2|    1|      3|       3|      2|
|    3|    2|      1|       3|      1|
+-----+-----+-------+--------+-------+
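
As a follow-up, the same logic can stay in the DataFrame API instead of an SQL string, assuming PySpark 3.1 or later, where the higher-order function pyspark.sql.functions.filter is available:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

## No orderBy, so the frame spans the entire group, matching
## "over (partition by group)" in the expr version.
w = Window.partitionBy('group')

df3 = df.withColumn(
    'desired',
    F.size(
        F.filter(
            F.collect_list('low_num').over(w),  ## all low_num in the group
            lambda x: x < F.col('high_num'),    ## keep those below this row's high_num
        )
    ),
)

df3.show() should produce the same table as above.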
