Rolling correlation and average per group (last 3 rows) in PySpark

Question

I have a DataFrame like this:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [("ID1", 1, 5), ("ID1", 2, 6), ("ID1", 3, 7),
        ("ID1", 4, 4), ("ID1", 5, 2), ("ID1", 6, 2),
        ("ID2", 1, 4), ("ID2", 2, 6), ("ID2", 3, 1), ("ID2", 4, 1), ("ID2", 5, 4)]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()

+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1|   1|   5|
|ID1|   2|   6|
|ID1|   3|   7|
|ID1|   4|   4|
|ID1|   5|   2|
|ID1|   6|   2|
|ID2|   1|   4|
|ID2|   2|   6|
|ID2|   3|   1|
|ID2|   4|   1|
|ID2|   5|   4|
+---+----+----+

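For each group (ID), I want to compute a rolling correlation between colA and colB and a rolling average of colB over the last 3 rows, i.e. the current row plus up to two preceding rows, ordered by colA.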

So, for ID1:
for the first element (5): average = 5, corr = 0
for the first 2 elements (5, 6): average = 5.5, corr with colA = 1
for the first 3 elements (5, 6, 7): average = 6, corr with colA = 1
for the last 3 elements up to colA = 4, i.e. (6, 7, 4): average = 5.66, corr with colA = -0.65
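The hand-computed values above can be checked without Spark. The following is a minimal plain-Python sketch, not part of the original question: for each row it takes the current row and up to two preceding rows of the same ID (ordered by colA), averages colB, and computes the Pearson correlation between colA and colB, using 0 for the single-row case as the question specifies. The helpers `pearson` and `rolling_last3` are hypothetical names chosen for illustration.

```python
from collections import defaultdict
from math import sqrt


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences, or None if a variance is zero."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    if vx == 0 or vy == 0:
        return None  # correlation is undefined for a constant column
    return cov / sqrt(vx * vy)


def rolling_last3(rows):
    """rows: (ID, colA, colB) tuples -> (ID, colA, colB, corr_last3, avg_last3) tuples."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[0]].append(row)
    out = []
    for key in sorted(groups):
        grp = sorted(groups[key], key=lambda r: r[1])  # order by colA within the ID
        for i, (id_, a, b) in enumerate(grp):
            window = grp[max(0, i - 2): i + 1]  # current row plus up to 2 preceding rows
            xs = [r[1] for r in window]
            ys = [r[2] for r in window]
            # The question defines the single-row case as corr = 0
            corr = 0.0 if len(window) == 1 else pearson(xs, ys)
            out.append((id_, a, b, corr, sum(ys) / len(ys)))
    return out


data = [("ID1", 1, 5), ("ID1", 2, 6), ("ID1", 3, 7),
        ("ID1", 4, 4), ("ID1", 5, 2), ("ID1", 6, 2),
        ("ID2", 1, 4), ("ID2", 2, 6), ("ID2", 3, 1), ("ID2", 4, 1), ("ID2", 5, 4)]

for row in rolling_last3(data):
    print(row)
```

For the 4th ID1 row the sketch gives roughly -0.655 and 5.667, matching the (6, 7, 4) calculation above (the question truncates the average to 5.66).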


The expected output looks like this:

    +---+----+----+----------+---------+
    | ID|colA|colB|corr_last3|avg_last3|
    +---+----+----+----------+---------+
    |ID1|   1|   5|         0|        5|
    |ID1|   2|   6|         1|      5.5|
    |ID1|   3|   7|         1|        6|
    |ID1|   4|   4|     -0.65|     5.66|
    |ID1|   5|   2|     -0.99|     4.33|
    |ID1|   6|   2|     -0.86|     2.66|
    |ID2|   1|   4|         0|        4|
    |ID2|   2|   6|         1|        5|
    |ID2|   3|   1|     -0.59|     3.66|
    |ID2|   4|   1|     -0.86|     2.66|
    |ID2|   5|   4|      0.86|        2|
    +---+----+----+----------+---------+

Tags: apache-spark, pyspark, apache-spark-sql, pyspark-dataframes

Answer


You can do this with the built-in aggregate functions avg and corr used as window functions. Here is a Scala solution:

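import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._ // assumes a SparkSession named spark, as in spark-shell

val w = Window.partitionBy($"ID").orderBy($"colA")
val last3 = w.rowsBetween(-2L, Window.currentRow) // current row + up to 2 preceding rows

df
  .withColumn("indices", row_number().over(w))
  // corr over a single-row frame is undefined, so the first row of each ID is set to 0.0
  .withColumn("corr_last3", when($"indices" > 1, corr($"colA", $"colB").over(last3)).otherwise(0.0))
  .withColumn("avg_last3", avg($"colB").over(last3))
  .drop($"indices")
  .orderBy($"ID", $"colA")
  .show()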

This gives:

+---+----+----+-------------------+------------------+
| ID|colA|colB|         corr_last3|         avg_last3|
+---+----+----+-------------------+------------------+
|ID1|   1|   5|                0.0|               5.0|
|ID1|   2|   6|                1.0|               5.5|
|ID1|   3|   7|                1.0|               6.0|
|ID1|   4|   4|-0.6546536707079772| 5.666666666666667|
|ID1|   5|   2|-0.9933992677987828| 4.333333333333333|
|ID1|   6|   2|-0.8660254037844386|2.6666666666666665|
|ID2|   1|   4|                0.0|               4.0|
|ID2|   2|   6|                1.0|               5.0|
|ID2|   3|   1|-0.5960395606792697|3.6666666666666665|
|ID2|   4|   1|-0.8660254037844387|2.6666666666666665|
|ID2|   5|   4| 0.8660254037844387|               2.0|
+---+----+----+-------------------+------------------+
