scala - How do I rank over a sliding window in Spark with Scala?
Problem description
I have a dataset:
+-----+-------------------+---------------------+------------------+
|query|similar_queries |model_score |count |
+-----+-------------------+---------------------+------------------+
|shirt|funny shirt |0.0034038130658784866|189.0 |
|shirt|shirt womens |0.0019435265241921438|136.0 |
|shirt|watch |0.001097496453284101 |212.0 |
|shirt|necklace |6.694577024597908E-4 |151.0 |
|shirt|white shirt |0.0037413097560623485|217.0 |
|shirt|shoes |0.0022062579255572733|575.0 |
|shirt|crop top |9.065831060804897E-4 |173.0 |
|shirt|polo shirts for men|0.007706416273211698 |349.0 |
|shirt|shorts |0.002669621942466027 |200.0 |
|shirt|black shirt |0.03264296242546658 |114.0 |
+-----+-------------------+---------------------+------------------+
I first rank the dataset by "count".
lazy val countWindowByFreq = Window.partitionBy(col(QUERY)).orderBy(col(COUNT).desc)
val ranked_data = data.withColumn("count_rank", row_number().over(countWindowByFreq))
+-----+-------------------+---------------------+------------------+----------+
|query|similar_queries |model_score |count |count_rank|
+-----+-------------------+---------------------+------------------+----------+
|shirt|shoes |0.0022062579255572733|575.0 |1 |
|shirt|polo shirts for men|0.007706416273211698 |349.0 |2 |
|shirt|white shirt |0.0037413097560623485|217.0 |3 |
|shirt|watch |0.001097496453284101 |212.0 |4 |
|shirt|shorts |0.002669621942466027 |200.0 |5 |
|shirt|funny shirt |0.0034038130658784866|189.0 |6 |
|shirt|crop top |9.065831060804897E-4 |173.0 |7 |
|shirt|necklace |6.694577024597908E-4 |151.0 |8 |
|shirt|shirt womens |0.0019435265241921438|136.0 |9 |
|shirt|black shirt |0.03264296242546658 |114.0 |10 |
+-----+-------------------+---------------------+------------------+----------+
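I am now trying to rank the rows using a rolling window of 4 rows over row_number, ordering by model_score within each window. For example:
In the first window, row_number 1 to 4, the new rank (a new column) would be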
1. polo shirts for men
2. white shirt
3. shoes
4. watch
In the second window, row_number 5 to 8, the new rank (a new column) would be
5. funny shirt
6. shorts
7. crop top
8. necklace
In the third window, row_number 9 through the remaining rows, the new rank (a new column) would be
9. black shirt
10. shirt womens
Can someone tell me how to achieve this with Spark and Scala? Is there a predefined function I can use?
I tried:
lazy val MODEL_RANK = Window.partitionBy(col(QUERY)).orderBy(col(MODEL_SCORE).desc).rowsBetween(0, 3)
But this gives me:
sql.AnalysisException: Window Frame ROWS BETWEEN CURRENT ROW AND 3 FOLLOWING must match the required frame ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW;
I also tried .rowsBetween(-3, 0), but that gives an error as well:
org.apache.spark.sql.AnalysisException: Window Frame ROWS BETWEEN 3 PRECEDING AND CURRENT ROW must match the required frame ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW;
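Solution
As the error messages indicate, row_number only accepts the frame ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, so rowsBetween cannot be used to limit it to four rows. Since you have already computed count_rank, the next step is to find a way to group the rows into sets of four. This can be done as follows: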
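import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import spark.implicits._ // enables the $"..." column syntax; assumes a SparkSession named `spark`
// count_rank 1-4 -> bucket 0, 5-8 -> bucket 1, and so on
val ranked_data_grouped = ranked_data
  .withColumn("bucket", (($"count_rank" - 1) / 4).cast(IntegerType))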
ranked_data_grouped looks like:
+-----+-------------------+---------------------+------------------+----------+-------+
|query|similar_queries |model_score |count |count_rank|bucket |
+-----+-------------------+---------------------+------------------+----------+-------+
|shirt|shoes |0.0022062579255572733|575.0 |1 |0 |
|shirt|polo shirts for men|0.007706416273211698 |349.0 |2 |0 |
|shirt|white shirt |0.0037413097560623485|217.0 |3 |0 |
|shirt|watch |0.001097496453284101 |212.0 |4 |0 |
|shirt|shorts |0.002669621942466027 |200.0 |5 |1 |
|shirt|funny shirt |0.0034038130658784866|189.0 |6 |1 |
|shirt|crop top |9.065831060804897E-4 |173.0 |7 |1 |
|shirt|necklace |6.694577024597908E-4 |151.0 |8 |1 |
|shirt|shirt womens |0.0019435265241921438|136.0 |9 |2 |
|shirt|black shirt |0.03264296242546658 |114.0 |10 |2 |
+-----+-------------------+---------------------+------------------+----------+-------+
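The bucket arithmetic can be checked in isolation. The following is a minimal sketch in plain Scala, not Spark code; `BucketSketch` and `bucketOf` are hypothetical names introduced only for illustration. In Spark, `/` on integer columns yields a double, which is why the snippet above casts the result to IntegerType; with plain Scala `Int` values, integer division already truncates.

```scala
object BucketSketch {
  // Zero-based bucket for a one-based rank:
  // ranks 1..windowSize map to 0, the next windowSize ranks map to 1, and so on.
  def bucketOf(rank: Int, windowSize: Int): Int = (rank - 1) / windowSize
}
```

For example, `(1 to 10).map(BucketSketch.bucketOf(_, 4))` produces 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, which matches the bucket column in the table above.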
Now all you have to do is partition by bucket and order by model_score:
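// Partition by query as well, so that buckets belonging to different queries are never mixed
val output = ranked_data_grouped
  .withColumn("finalRank", row_number().over(Window.partitionBy($"query", $"bucket").orderBy($"model_score".desc)))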
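Note that finalRank as computed here restarts at 1 inside every bucket, whereas the question expects the numbering to continue (5 to 8 in the second window, 9 onward in the third). In Spark, one way to get that would presumably be to add an offset, for example `$"bucket" * 4 + row_number().over(...)`. To sanity-check the logic without a cluster, the three steps (rank by count, bucket, rank by model_score) can be mirrored with plain Scala collections. This is only a sketch under those assumptions: `SlidingRankSketch`, `QueryRow`, and `rankInWindows` are hypothetical names, and the sample rows are copied from the question.

```scala
object SlidingRankSketch {
  // One row of the example dataset; field names mirror the DataFrame columns.
  final case class QueryRow(query: String, similarQuery: String, modelScore: Double, count: Double)

  // Hypothetical helper: returns (query, similarQuery, finalRank) triples,
  // with finalRank continuing across windows instead of restarting at 1.
  def rankInWindows(rows: Seq[QueryRow], windowSize: Int): Seq[(String, String, Int)] =
    rows.groupBy(_.query).toSeq.sortBy(_._1).flatMap { case (query, group) =>
      group
        .sortWith(_.count > _.count)          // step 1: count_rank order (descending count)
        .grouped(windowSize)                  // step 2: buckets of consecutive count ranks
        .zipWithIndex
        .flatMap { case (bucketRows, bucket) =>
          bucketRows
            .sortWith(_.modelScore > _.modelScore) // step 3: order by model_score in the bucket
            .zipWithIndex
            .map { case (r, i) => (query, r.similarQuery, bucket * windowSize + i + 1) }
        }
        .toSeq
    }

  // Sample rows taken from the question.
  val sample: Seq[QueryRow] = Seq(
    QueryRow("shirt", "funny shirt", 0.0034038130658784866, 189.0),
    QueryRow("shirt", "shirt womens", 0.0019435265241921438, 136.0),
    QueryRow("shirt", "watch", 0.001097496453284101, 212.0),
    QueryRow("shirt", "necklace", 6.694577024597908E-4, 151.0),
    QueryRow("shirt", "white shirt", 0.0037413097560623485, 217.0),
    QueryRow("shirt", "shoes", 0.0022062579255572733, 575.0),
    QueryRow("shirt", "crop top", 9.065831060804897E-4, 173.0),
    QueryRow("shirt", "polo shirts for men", 0.007706416273211698, 349.0),
    QueryRow("shirt", "shorts", 0.002669621942466027, 200.0),
    QueryRow("shirt", "black shirt", 0.03264296242546658, 114.0)
  )
}
```

Calling `SlidingRankSketch.rankInWindows(SlidingRankSketch.sample, 4)` ranks polo shirts for men, white shirt, shoes, and watch as 1 to 4, then funny shirt, shorts, crop top, and necklace as 5 to 8, and finally black shirt and shirt womens as 9 and 10.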