scala - Spark - 仅对几个最小的项目进行分组和聚合
问题描述
简而言之
我有两个数据框和函数的笛卡尔积(交叉连接),它为这个产品的给定元素提供了一些分数。我现在想为第一个 DF 的每个成员获取第二个 DF 的几个“最佳匹配”元素。
详细介绍
下面是一个简化的示例,因为我的真实代码有些臃肿,带有额外的字段和过滤器。
给定两组数据,每组都有一些 id 和 value:
// simple rdds of tuples
val rdd1 = sc.parallelize(Seq(("a", 31),("b", 41),("c", 59),("d", 26),("e",53),("f",58)))
val rdd2 = sc.parallelize(Seq(("z", 16),("y", 18),("x",3),("w",39),("v",98), ("u", 88)))
// convert them to dataframes:
val df1 = spark.createDataFrame(rdd1).toDF("id1", "val1")
val df2 = spark.createDataFrame(rdd2).toDF("id2", "val2")
和一些函数,对于来自第一个和第二个数据集的元素对给出它们的“匹配分数”:
def f(a:Int, b:Int):Int = (a * a + b * b * b) % 17
// convert it to udf
val fu = udf((a:Int, b:Int) => f(a, b))
我们可以创建两组的乘积并计算每对的分数:
val dfc = df1.crossJoin(df2)
val r = dfc.withColumn("rez", fu(col("val1"), col("val2")))
r.show
+---+----+---+----+---+
|id1|val1|id2|val2|rez|
+---+----+---+----+---+
| a| 31| z| 16| 8|
| a| 31| y| 18| 10|
| a| 31| x| 3| 2|
| a| 31| w| 39| 15|
| a| 31| v| 98| 13|
| a| 31| u| 88| 2|
| b| 41| z| 16| 14|
| c| 59| z| 16| 12|
...
现在我们想让这个结果按以下方式分组id1
:
r.groupBy("id1").agg(collect_set(struct("id2", "rez")).as("matches")).show
+---+--------------------+
|id1| matches|
+---+--------------------+
| f|[[v,2], [u,8], [y...|
| e|[[y,5], [z,3], [x...|
| d|[[w,2], [x,6], [v...|
| c|[[w,2], [x,6], [v...|
| b|[[v,2], [u,8], [y...|
| a|[[x,2], [y,10], [...|
+---+--------------------+
但实际上我们只想保留少数(比如 3 个)“匹配项”,即那些得分最高(比如得分最低)的匹配项。
问题是
如何将“匹配”排序并减少到前 N 个元素?可能是关于 collect_list 和 sort_array 的东西,虽然我不知道如何按内部字段排序。
有没有办法确保在大输入 DF 的情况下进行优化 - 例如在聚合时直接选择最小值。我知道如果我在没有火花的情况下编写代码,这可以很容易地完成 - 为每个元素保留小数组或优先级队列,
id1
并在它应该在的地方添加元素,可能会删除之前添加的一些元素。
例如,交叉连接是昂贵的操作是可以的,但我想避免将内存浪费在我将在下一步中删除的大部分结果上。我的真实用例处理条目少于 100 万个的 DF,因此交叉连接仍然可行,但由于我们只想为每个匹配选择 10-20 个顶级匹配,id1
因此似乎非常希望不要在步骤之间保留不必要的数据。
解决方案
首先,我们只需要取前 n 行。为此,我们按“id1”对 DF 进行分区,并按 res 对组进行排序。我们使用它将行号列添加到 DF,就像我们可以使用where函数获取前 n 行一样。您可以继续执行您编写的相同代码。按“id1”分组并收集列表。只是现在你已经有了最高的行。
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val n = 3
val w = Window.partitionBy($"id1").orderBy($"res".desc)
val res = r.withColumn("rn", row_number.over(w)).where($"rn" <= n).groupBy("id1").agg(collect_set(struct("id2", "res")).as("matches"))
第二个选项可能更好,因为您不需要将 DF 分组两次:
val sortTakeUDF = udf{(xs: Seq[Row], n: Int)} => xs.sortBy(_.getAs[Int]("res")).reverse.take(n).map{case Row(x: String, y:Int)}}
r.groupBy("id1").agg(sortTakeUDF(collect_set(struct("id2", "res")), lit(n)).as("matches"))
在这里,我们创建一个带有数组列和整数值 n 的 udf。udf 按您的“res”对数组进行排序,并仅返回前 n 个元素。
推荐阅读
- c++ - C++中常量初始化和零初始化顺序的矛盾定义
- vim - 有没有办法在vim中向上或向下复制选定区域?
- flutter - 如何在 Flutter 中更新现有文件(Path_Provider)
- r - 错误:连接到互联网时,如果没有互联网,天桥查询不可用
- php - 如何防止 Laravel 应用程序文件夹被搜索引擎索引
- slack - slack huddle 没有突然出现
- flutter - Flutter 中的 SearchDelegate:在 null 上调用了“where”方法
- html - 如何在引导手风琴的标题上添加按钮,以便单击它不会折叠或展开它?
- html - 反应性 R 闪亮与 HTML:NA 值保存问题
- ruby-on-rails - Rails 按键分组哈希数组