首页 > 解决方案 > spark scala - 按数组列分组

问题描述

我对火花斯卡拉很陌生。感谢您的帮助..我有一个数据框

val df = Seq(
  ("a", "a1", Array("x1","x2")), 
  ("a", "b1", Array("x1")),
  ("a", "c1", Array("x2")),
  ("c", "c3", Array("x2")),
  ("a", "d1", Array("x3")),
  ("a", "e1", Array("x2","x1"))
).toDF("k1", "k2", "k3")

我正在寻找一种按 k1 和 k3 对其进行分组并将 k2 收集到数组中的方法。但是,k3 是一个数组,我需要为分组应用包含(而不是完全匹配)。换句话说,我正在寻找这样的结果

k1   k3       k2                count
a   (x1,x2)   (a1,b1,c1,e1)     4
a    (x3)      (d1)             1
c    (x2)      (c3)             1

有人可以建议如何实现这一目标吗?

提前致谢!

标签: arraysscalaapache-sparkmapreduce

解决方案


我建议您按 k1 列分组收集 k2 和 k3 的结构列表将收集的列表传递给 udf 函数,以计算 k3 中的数组何时包含在另一个 k3 数组中并添加 k2 的元素。

然后您可以使用explodeselect表达式来获得所需的输出

以下是完整的工作解决方案

val df = Seq(
  ("a", "a1", Array("x1","x2")),
  ("a", "b1", Array("x1")),
  ("a", "c1", Array("x2")),
  ("c", "c3", Array("x2")),
  ("a", "d1", Array("x3")),
  ("a", "e1", Array("x2","x1"))
  ).toDF("k1", "k2", "k3")

import org.apache.spark.sql.functions._
def containsGoupingUdf = udf((arr: Seq[Row]) => {
  val firstStruct =  arr.head
  val tailStructs =  arr.tail
  var result = Array((collection.mutable.Set(firstStruct.getAs[String]("k2")), firstStruct.getAs[scala.collection.mutable.WrappedArray[String]]("k3").toSet, 1))
  for(str <- tailStructs){
    var added = false
    for((res, index) <- result.zipWithIndex) {
      if (str.getAs[scala.collection.mutable.WrappedArray[String]]("k3").exists(res._2) || res._2.exists(x => str.getAs[scala.collection.mutable.WrappedArray[String]]("k3").contains(x))) {
        result(index) = (res._1 + str.getAs[String]("k2"), res._2 ++ str.getAs[scala.collection.mutable.WrappedArray[String]]("k3").toSet, res._3 + 1)
        added = true
      }
    }
    if(!added){
      result = result ++ Array((collection.mutable.Set(str.getAs[String]("k2")), str.getAs[scala.collection.mutable.WrappedArray[String]]("k3").toSet, 1))
    }
  }
  result.map(tuple => (tuple._1.toArray, tuple._2.toArray, tuple._3))
})

df.groupBy("k1").agg(containsGoupingUdf(collect_list(struct(col("k2"), col("k3")))).as("aggregated"))
    .select(col("k1"), explode(col("aggregated")).as("aggregated"))
    .select(col("k1"), col("aggregated._2").as("k3"), col("aggregated._1").as("k2"), col("aggregated._3").as("count"))
  .show(false)

这应该给你

+---+--------+----------------+-----+
|k1 |k3      |k2              |count|
+---+--------+----------------+-----+
|c  |[x2]    |[c3]            |1    |
|a  |[x1, x2]|[b1, e1, c1, a1]|4    |
|a  |[x3]    |[d1]            |1    |
+---+--------+----------------+-----+ 

我希望答案是有帮助的,您可以根据自己的需要进行修改。


推荐阅读