首页 > 解决方案 > 如何对 Spark 数据框中的记录组执行任意计算?

问题描述

我有一个这样的数据框:

|-----+-----+-------+---------|
| foo | bar | fox   | cow     |
|-----+-----+-------+---------|
|   1 |   2 | red   | blue    | // row 0
|   1 |   2 | red   | yellow  | // row 1
|   2 |   2 | brown | green   | // row 2
|   3 |   4 | taupe | fuschia | // row 3
|   3 |   4 | red   | orange  | // row 4
|-----+-----+-------+---------|

我需要按“foo”和“bar”对记录进行分组,然后对“fox”和“cow”执行一些神奇的计算以生成“badger”,它可以插入或删除记录:

|-----+-----+-------+---------+---------|
| foo | bar | fox   | cow     | badger  |
|-----+-----+-------+---------+---------|
|   1 |   2 | red   | blue    | zebra   |
|   1 |   2 | red   | blue    | chicken |
|   1 |   2 | red   | yellow  | cougar  |
|   2 |   2 | brown | green   | duck    |
|   3 |   4 | red   | orange  | peacock |
|-----+-----+-------+---------+---------|

(在此示例中,第 0 行已拆分为两个“badger”值,第 3 行已从最终输出中删除。)

到目前为止,我最好的方法如下所示:

val groups = df.select("foo", "bar").distinct
groups.flatMap(row => {
  val (foo, bar): (String, String) = (row(0), row(1))
  val group: DataFrame = df.where(s"foo == '$foo' AND bar == '$bar'")
  val rowsWithBadgers: List[Row] = makeBadgersFor(group)
  rowsWithBadgers
})

这种方法有几个问题:

  1. 单独匹配很foo笨拙bar。(实用方法可以清理它,所以没什么大不了的。)
  2. Invalid tree: null\nnull由于我尝试df从内部引用的嵌套操作,它会引发错误groups.flatMap。不知道如何绕过那个。
  3. 我不确定这种映射和过滤是否真正正确地利用了 Spark 分布式计算。

有没有更高效和/或优雅的方法来解决这个问题?

这个问题与Spark DataFrame: operation on groups非常相似,但我将其包含在此处是因为 1)不清楚该问题是否需要添加和删除记录,以及 2)该问题中的答案已过时并且缺乏细节。

我看不到使用groupBy用户定义的聚合函数来完成此操作的方法,因为聚合函数聚合到单行。换句话说,

udf(<records with foo == 'foo' && bar == 'bar'>) => [foo,bar,aggregatedValue]

在分析我的组后,我可能需要返回两个或更多不同的行或零行。我没有看到聚合函数这样做的方法——如果你有一个例子,请分享。

标签: scalaapache-spark

解决方案


可以使用用户定义的函数。返回的单行可以包含一个列表。然后,您可以将列表分解为多行并重建列。

聚合器:

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders.kryo
import org.apache.spark.sql.expressions.Aggregator

case class StuffIn(foo: BigInt, bar: BigInt, fox: String, cow: String)
case class StuffOut(foo: BigInt, bar: BigInt, fox: String, cow: String, badger: String)
object StuffOut {
  def apply(stuffIn: StuffIn): StuffOut = new StuffOut(stuffIn.foo, 
stuffIn.bar, stuffIn.fox, stuffIn.cow, "dummy")
}

object MultiLineAggregator extends Aggregator[StuffIn, Seq[StuffOut], Seq[StuffOut]] {
  def zero: Seq[StuffOut] = Seq[StuffOut]()
  def reduce(buffer: Seq[StuffOut], stuff: StuffIn): Seq[StuffOut] = {
    makeBadgersForDummy(buffer, stuff)
  }

  def merge(b1: Seq[StuffOut], b2: Seq[StuffOut]): Seq[StuffOut] = {
    b1 ++: b2
  }
  def finish(reduction: Seq[StuffOut]): Seq[StuffOut] = reduction
  def bufferEncoder: Encoder[Seq[StuffOut]] = kryo[Seq[StuffOut]]
  def outputEncoder: Encoder[Seq[StuffOut]] = kryo[Seq[StuffOut]]
}

来电:

val averageSalary: TypedColumn[StuffIn, Seq[StuffOut]] = MultiLineAggregator.toColumn

val res: DataFrame =
  ds.groupByKey(x => (x.foo, x.bar))
          .agg(averageSalary)
          .map(_._2)
          .withColumn("value", explode($"value"))
          .withColumn("foo", $"value.foo")
          .withColumn("bar", $"value.bar")
          .withColumn("fox", $"value.fox")
          .withColumn("cow", $"value.cow")
          .withColumn("badger", $"value.badger")
          .drop("value")

推荐阅读