pyspark: check whether each name has 3 rows of data

Problem description

In pyspark, I have a DataFrame like the one below. I want to check whether each name has rows for all 3 actions (0, 1, 2). If an action is missing, add a new row with the score column set to 0 and the other columns (e.g. str1, str2, str3) unchanged.

+-----+--------+--------+--------+-------+-------+
| name|  str1  |  str2  |  str3  | action| score |
+-----+--------+--------+--------+-------+-------+
|  A  | str_A1 | str_A2 | str_A3 |      0|      2|
|  A  | str_A1 | str_A2 | str_A3 |      1|      6|
|  A  | str_A1 | str_A2 | str_A3 |      2|     74|
|  B  | str_B1 | str_B2 | str_B3 |      0|     59|
|  B  | str_B1 | str_B2 | str_B3 |      1|     18|
|  C  | str_C1 | str_C2 | str_C3 |      0|      3|
|  C  | str_C1 | str_C2 | str_C3 |      1|     33|
|  C  | str_C1 | str_C2 | str_C3 |      2|      3|
+-----+--------+--------+--------+-------+-------+
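To reproduce the examples, here is a minimal sketch that builds this sample DataFrame (it assumes an active SparkSession named spark):

# Hypothetical setup for the sample data shown above.
df = spark.createDataFrame(
    [
        ("A", "str_A1", "str_A2", "str_A3", 0, 2),
        ("A", "str_A1", "str_A2", "str_A3", 1, 6),
        ("A", "str_A1", "str_A2", "str_A3", 2, 74),
        ("B", "str_B1", "str_B2", "str_B3", 0, 59),
        ("B", "str_B1", "str_B2", "str_B3", 1, 18),
        ("C", "str_C1", "str_C2", "str_C3", 0, 3),
        ("C", "str_C1", "str_C2", "str_C3", 1, 33),
        ("C", "str_C1", "str_C2", "str_C3", 2, 3),
    ],
    ["name", "str1", "str2", "str3", "action", "score"],
)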

For example, name B has no row for action 2, so one new row is added:

+-----+--------+--------+--------+-------+-------+
| name|  str1  |  str2  |  str3  | action| score |
+-----+--------+--------+--------+-------+-------+
|  A  | str_A1 | str_A2 | str_A3 |      0|      2|
|  A  | str_A1 | str_A2 | str_A3 |      1|      6|
|  A  | str_A1 | str_A2 | str_A3 |      2|     74|
|  B  | str_B1 | str_B2 | str_B3 |      0|     59|
|  B  | str_B1 | str_B2 | str_B3 |      1|     18|
|  B  | str_B1 | str_B2 | str_B3 |      2|      0|<---- new row data
|  C  | str_C1 | str_C2 | str_C3 |      0|      3|
|  C  | str_C1 | str_C2 | str_C3 |      1|     33|
|  C  | str_C1 | str_C2 | str_C3 |      2|      3|
+-----+--------+--------+--------+-------+-------+

A name may also have only a single row, in which case two new rows need to be added.

+-----+--------+--------+--------+-------+-------+
| name|  str1  |  str2  |  str3  | action| score |
+-----+--------+--------+--------+-------+-------+
|  A  | str_A1 | str_A2 | str_A3 |      0|      2|
|  A  | str_A1 | str_A2 | str_A3 |      1|      6|
|  A  | str_A1 | str_A2 | str_A3 |      2|     74|
|  B  | str_B1 | str_B2 | str_B3 |      0|     59|
|  B  | str_B1 | str_B2 | str_B3 |      1|     18|
|  B  | str_B1 | str_B2 | str_B3 |      2|      0| 
|  C  | str_C1 | str_C2 | str_C3 |      0|      3|
|  C  | str_C1 | str_C2 | str_C3 |      1|     33|
|  C  | str_C1 | str_C2 | str_C3 |      2|      3|
|  D  | str_D1 | str_D2 | str_D3 |      0|     45|
+-----+--------+--------+--------+-------+-------+

+-----+--------+--------+--------+-------+-------+
| name|  str1  |  str2  |  str3  | action| score |
+-----+--------+--------+--------+-------+-------+
|  A  | str_A1 | str_A2 | str_A3 |      0|      2|
|  A  | str_A1 | str_A2 | str_A3 |      1|      6|
|  A  | str_A1 | str_A2 | str_A3 |      2|     74|
|  B  | str_B1 | str_B2 | str_B3 |      0|     59|
|  B  | str_B1 | str_B2 | str_B3 |      1|     18|
|  B  | str_B1 | str_B2 | str_B3 |      2|      0| 
|  C  | str_C1 | str_C2 | str_C3 |      0|      3|
|  C  | str_C1 | str_C2 | str_C3 |      1|     33|
|  C  | str_C1 | str_C2 | str_C3 |      2|      3|
|  D  | str_D1 | str_D2 | str_D3 |      0|     45|
|  D  | str_D1 | str_D2 | str_D3 |      1|      0|<---- new row data
|  D  | str_D1 | str_D2 | str_D3 |      2|      0|<---- new row data
+-----+--------+--------+--------+-------+-------+

I am new to pyspark and do not know how to do this. Thanks for your help.

Tags: pyspark, apache-spark-sql

Solution


Solution using a UDF

from pyspark.sql import functions as F, types as T

# UDF that fills in the missing actions: for each action 0, 1, 2 it keeps
# the existing score and defaults to 0 when the action is absent.
# (This assumes `action` and `score` are integer columns, as in the tables
# above; IntegerType keys keep the exploded `action` column an integer.)
@F.udf(T.MapType(T.IntegerType(), T.IntegerType()))
def add_missing_values(values):
    return {i: values.get(i, 0) for i in range(3)}

df = (
    # Collapse each name into a single row holding an action -> score map.
    df.groupBy("name", "str1", "str2", "str3")
    .agg(
        F.map_from_entries(F.collect_list(F.struct("action", "score"))).alias("values")
    )
    # Complete the map so it always has entries for actions 0, 1 and 2.
    .withColumn("values", add_missing_values(F.col("values")))
    # Exploding a map yields one row per entry; the two aliases name the
    # key and value columns.
    .select(
        "name", "str1", "str2", "str3", F.explode("values").alias("action", "score")
    )
)

df.show()

+----+------+------+------+------+-----+                                        
|name|  str1|  str2|  str3|action|score|
+----+------+------+------+------+-----+
|   A|str_A1|str_A2|str_A3|     0|    2|
|   A|str_A1|str_A2|str_A3|     1|    6|
|   A|str_A1|str_A2|str_A3|     2|   74|
|   B|str_B1|str_B2|str_B3|     0|   59|
|   B|str_B1|str_B2|str_B3|     1|   18|
|   B|str_B1|str_B2|str_B3|     2|    0|<---- new row data
|   C|str_C1|str_C2|str_C3|     0|    3|
|   C|str_C1|str_C2|str_C3|     1|   33|
|   C|str_C1|str_C2|str_C3|     2|    3|
|   D|str_D1|str_D2|str_D3|     0|   45|
|   D|str_D1|str_D2|str_D3|     1|    0|<---- new row data
|   D|str_D1|str_D2|str_D3|     2|    0|<---- new row data
+----+------+------+------+------+-----+
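Note that the UDF runs in Python worker processes, so each group's map is serialized out of the JVM and back. The version below expresses the same fill-in logic with built-in Spark functions only, which avoids that round trip.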

Full Spark solution:

df = (
    # Same aggregation as above: one action -> score map per name.
    df.groupBy("name", "str1", "str2", "str3")
    .agg(
        F.map_from_entries(F.collect_list(F.struct("action", "score"))).alias("values")
    )
    # Rebuild the map over the fixed key set {0, 1, 2}: keep the existing
    # score for each action and fall back to 0 where the key is missing.
    .withColumn(
        "values",
        F.map_from_arrays(
            F.array([F.lit(i) for i in range(3)]),
            F.array(
                [F.coalesce(F.col("values").getItem(i), F.lit(0)) for i in range(3)]
            ),
        ),
    )
    .select(
        "name", "str1", "str2", "str3", F.explode("values").alias("action", "score")
    )
)
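Calling df.show() on this version yields the same table as the UDF solution above. For completeness, the same result can also be reached with joins instead of maps: build the full (name, action) grid and left-join the scores back, filling the gaps with 0. A minimal sketch (again assuming an active SparkSession named spark; the actions and result names are illustrative):

actions = spark.range(3).withColumnRenamed("id", "action")  # actions 0, 1, 2

result = (
    # start again from the original df, before the transformations above
    df.select("name", "str1", "str2", "str3").distinct()
    .crossJoin(actions)  # every name paired with every action
    .join(df, ["name", "str1", "str2", "str3", "action"], "left")
    .fillna({"score": 0})  # combinations missing from df get score 0
)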
