apache-spark - Rolling correlation and average (last 3) per group in PySpark
Problem description
I have a dataframe like this:
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
data = [("ID1", 1, 5), ("ID1", 2, 6), ("ID1", 3, 7),
        ("ID1", 4, 4), ("ID1", 5, 2), ("ID1", 6, 2),
        ("ID2", 1, 4), ("ID2", 2, 6), ("ID2", 3, 1), ("ID2", 4, 1), ("ID2", 5, 4)]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()
+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1|   1|   5|
|ID1|   2|   6|
|ID1|   3|   7|
|ID1|   4|   4|
|ID1|   5|   2|
|ID1|   6|   2|
|ID2|   1|   4|
|ID2|   2|   6|
|ID2|   3|   1|
|ID2|   4|   1|
|ID2|   5|   4|
+---+----+----+
For each group, I want to compute a rolling correlation (colB against colA) and a rolling average of colB over the last 3 rows, that is, the current row and up to two rows before it.
So for ID1:
for the first element (5): average = 5, corr = 0
for the first 2 elements (5, 6): average = 5.5, corr with colA = 1
for the first 3 elements (5, 6, 7): average = 6, corr with colA = 1
for the elements (6, 7, 4): average = 5.66, corr with colA = -0.65
The expected output looks like this:
+---+----+----+----------+---------+
| ID|colA|colB|corr_last3|avg_last3|
+---+----+----+----------+---------+
|ID1|   1|   5|         0|        5|
|ID1|   2|   6|         1|      5.5|
|ID1|   3|   7|         1|        6|
|ID1|   4|   4|     -0.65|     5.66|
|ID1|   5|   2|     -0.99|     4.33|
|ID1|   6|   2|     -0.86|     2.66|
|ID2|   1|   4|         0|        4|
|ID2|   2|   6|         1|        5|
|ID2|   3|   1|     -0.59|     3.66|
|ID2|   4|   1|     -0.86|     2.66|
|ID2|   5|   4|      0.86|        2|
+---+----+----+----------+---------+
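The numbers in this table can be checked by hand without Spark. The following is a minimal plain-Python sketch of the same arithmetic, assuming a trailing window of up to 3 rows and a Pearson correlation that is reported as 0 when it is undefined (for example on a single row). The helper names `pearson` and `rolling_last3` are illustrative only and are not part of any library.

```python
from math import sqrt

def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0 or var_y == 0:
        # Assumption of this sketch: undefined correlations are reported as 0.
        return 0.0
    return cov / sqrt(var_x * var_y)

def rolling_last3(rows, size=3):
    """For (colA, colB) rows sorted by colA, return one (corr, avg) pair per row."""
    out = []
    for i in range(len(rows)):
        # Current row plus up to size - 1 preceding rows.
        window = rows[max(0, i - size + 1): i + 1]
        col_a = [r[0] for r in window]
        col_b = [r[1] for r in window]
        out.append((pearson(col_a, col_b), sum(col_b) / len(col_b)))
    return out

id1 = [(1, 5), (2, 6), (3, 7), (4, 4), (5, 2), (6, 2)]
id2 = [(1, 4), (2, 6), (3, 1), (4, 1), (5, 4)]

for name, rows in (("ID1", id1), ("ID2", id2)):
    for (a, b), (corr, avg) in zip(rows, rolling_last3(rows)):
        print(name, a, b, corr, avg)
```

The printed values agree with the expected table once you account for the table truncating to two decimals rather than rounding (5.666... is shown as 5.66).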
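Solution
You can do this with the built-in functions avg and corr, applied over a window. Here is a Scala solution. Note that it correlates colB with the row index rather than with colA itself, which gives the same result here because colA increases in steps of 1 within each group: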
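import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{avg, corr, row_number, when}
import spark.implicits._  // needed for the $"..." column syntax outside spark-shell
// Rows of each ID ordered by colA; last3 covers the current row and the two before it.
val byId = Window.partitionBy($"ID").orderBy($"colA")
val last3 = byId.rowsBetween(-2L, Window.currentRow)
df
  .withColumn("indices", row_number().over(byId))
  // The first row of each group has no defined correlation, so it is set to 0.0.
  .withColumn("corr_last3", when($"indices" > 1, corr($"indices", $"colB").over(last3)).otherwise(0.0))
  .withColumn("avg_last3", avg($"colB").over(last3))
  .drop($"indices")
  .orderBy($"ID", $"colA")
  .show()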
This gives:
+---+----+----+-------------------+------------------+
| ID|colA|colB|         corr_last3|         avg_last3|
+---+----+----+-------------------+------------------+
|ID1|   1|   5|                0.0|               5.0|
|ID1|   2|   6|                1.0|               5.5|
|ID1|   3|   7|                1.0|               6.0|
|ID1|   4|   4|-0.6546536707079772| 5.666666666666667|
|ID1|   5|   2|-0.9933992677987828| 4.333333333333333|
|ID1|   6|   2|-0.8660254037844386|2.6666666666666665|
|ID2|   1|   4|                0.0|               4.0|
|ID2|   2|   6|                1.0|               5.0|
|ID2|   3|   1|-0.5960395606792697|3.6666666666666665|
|ID2|   4|   1|-0.8660254037844387|2.6666666666666665|
|ID2|   5|   4| 0.8660254037844387|               2.0|
+---+----+----+-------------------+------------------+