首页 > 解决方案 > 在 pyspark 中创建相互依赖的列

问题描述

我有一个pyspark如下所示的数据框

df = spark.createDataFrame([
(124,10,8),
(124,20,7),
(125,30,6),
(125,40,5),
(126,50,4),
(126,60,3),
(126,70,2),
(127,80,1)],("ACC_KEY", "AMT", "value"))

df.show()

+-------+---+-----+
|ACC_KEY|AMT|value|
+-------+---+-----+
|    126| 70|    2|
|    126| 60|    3|
|    126| 50|    4|
|    124| 20|    7|
|    124| 10|    8|
|    127| 80|    1|
|    125| 40|    5|
|    125| 30|    6|
+-------+---+-----+

Expected result

+-------+---+-----+-------+-----+-------+
|ACC_KEY|AMT|value|row_now|amt_c|lkp_rev|
+-------+---+-----+-------+-----+-------+
|    126| 70|    2|      1|   70|     72|
|    126| 60|    3|      2|   72|     75|
|    126| 50|    4|      3|   75|     79|
|    124| 20|    7|      1|   20|     27|
|    124| 10|    8|      2|   27|     35|
|    127| 80|    1|      1|   80|     81|
|    125| 40|    5|      1|   40|     45|
|    125| 30|    6|      2|   45|     51|
+-------+---+-----+-------+-----+-------+

Conditions

1) When row_number = 1 then amt_c column = column AMT
2) when row_number != 1 then It should be the lag of column lkp_rev + column value
3) lkp_rev column = amt_c column + value column

我试过如下

import pyspark.sql.functions as f
from pyspark.sql import Window

# create row_number column
df1 = df.withColumn("row_now", f.row_number().over(Window.partitionBy("ACC_KEY").orderBy(f.col('AMT').desc())))

# amt_c column creation
df2 = df1.withColumn("amt_c", f.when(f.col("row_now") == 1, f.col("AMT")).otherwise(f.col("value") + f.col("AMT")))

我怎样才能达到我想要的

标签: apache-sparkpyspark

解决方案


我认为,如果将所有具有 的行分开row_now = 1,并将其作为“参考”数据框或每个acc_key.

首先,添加行号以便我们以后可以重用

df = df.withColumn('row_now', F.row_number().over(W.partitionBy('acc_key').orderBy(F.col('amt').desc())))
# +-------+---+-----+-------+
# |acc_key|amt|value|row_now|
# +-------+---+-----+-------+
# |    126| 70|    2|      1|
# |    126| 60|    3|      2|
# |    126| 50|    4|      3|
# |    124| 20|    7|      1|
# |    124| 10|    8|      2|
# |    127| 80|    1|      1|
# |    125| 40|    5|      1|
# |    125| 30|    6|      2|
# +-------+---+-----+-------+

我们现在需要制作一个“参考”数据框,它只包含初始数量(即row_now = 1

ref = (df
    .where(F.col('row_now') == 1)
    .drop('row_now', 'value')
    .withColumnRenamed('amt', 'init_amt')
)
# +-------+--------+
# |acc_key|init_amt|
# +-------+--------+
# |    126|      70|
# |    124|      20|
# |    127|      80|
# |    125|      40|
# +-------+--------+

最后,加入原来的,这样我们就有了应用lag功能的起点

(df
    .join(ref, ['acc_key'])
    .withColumn('temp', F
        .when(F.col('row_now') == 1, F.col('init_amt'))
        .otherwise(F.lag('value').over(W.partitionBy('acc_key').orderBy('row_now')))
    )
    .withColumn('amt_c', F.sum('temp').over(W.partitionBy('acc_key').orderBy('row_now')))
    .withColumn('lkp_rev', F.col('amt_c') + F.col('value'))
    .drop('init_amt', 'temp')
    .show()
)

# +-------+---+-----+-------+-----+-------+
# |acc_key|amt|value|row_now|amt_c|lkp_rev|
# +-------+---+-----+-------+-----+-------+
# |    126| 70|    2|      1|   70|     72|
# |    126| 60|    3|      2|   72|     75|
# |    126| 50|    4|      3|   75|     79|
# |    124| 20|    7|      1|   20|     27|
# |    124| 10|    8|      2|   27|     35|
# |    127| 80|    1|      1|   80|     81|
# |    125| 40|    5|      1|   40|     45|
# |    125| 30|    6|      2|   45|     51|
# +-------+---+-----+-------+-----+-------+

推荐阅读