Writing an aggregator that takes any sortable Spark data type

Problem description

I am learning about custom Spark aggregators and am trying to implement a "MinN" function that returns an array of the N smallest items in a column. I want it to work for integers, decimals, and timestamps.

This version only works for doubles:

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

case class MinN(col: String, cutoff: Int = 5)
  extends Aggregator[Row, ArrayBuffer[Double], ArrayBuffer[Double]] with Serializable {

  def zero = ArrayBuffer[Double]()

  def reduce(acc: ArrayBuffer[Double], x: Row) = {
    val curval = x.getAs[Double](col)
    if (acc.length < cutoff) {
      // Buffer is not full yet: just collect the value.
      acc.append(curval)
    } else {
      // Buffer is full: replace the largest of the current minimums if this value beats it.
      val maxOfMins = acc.max
      if (curval < maxOfMins) {
        acc(acc.indexOf(maxOfMins)) = curval
      }
    }
    acc
  }

  def merge(acc1: ArrayBuffer[Double], acc2: ArrayBuffer[Double]) =
    (acc1 ++ acc2).sorted.take(cutoff)

  def finish(acc: ArrayBuffer[Double]) = acc

  override def bufferEncoder: Encoder[ArrayBuffer[Double]] = ExpressionEncoder()
  override def outputEncoder: Encoder[ArrayBuffer[Double]] = ExpressionEncoder()
}

I then tried to make the aggregator generic by changing the declaration to MinN[T : Ordering], the comparison to implicitly[Ordering[T]].lt(curval, maxOfMins), and all of the [Double]s to [T]s. This gave the following compiler error:

Error:(58, 74) type mismatch;
 found   : org.apache.spark.sql.catalyst.encoders.ExpressionEncoder[Nothing]
 required: org.apache.spark.sql.Encoder[scala.collection.mutable.ArrayBuffer[T]]
Note: Nothing <: scala.collection.mutable.ArrayBuffer[T], but trait Encoder is invariant in type T.
You may wish to define T as +T instead. (SLS 4.5)
  override def bufferEncoder: Encoder[ArrayBuffer[T]] = ExpressionEncoder()
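
For completeness, the generic version looks roughly like this (a sketch reconstructed from the changes described above; the ExpressionEncoder() calls are where the compiler infers Nothing):

case class MinN[T : Ordering](col: String, cutoff: Int = 5)
  extends Aggregator[Row, ArrayBuffer[T], ArrayBuffer[T]] with Serializable {

  def zero = ArrayBuffer[T]()

  def reduce(acc: ArrayBuffer[T], x: Row) = {
    val curval = x.getAs[T](col)
    if (acc.length < cutoff) {
      acc.append(curval)
    } else {
      val maxOfMins = acc.max
      if (implicitly[Ordering[T]].lt(curval, maxOfMins)) {
        acc(acc.indexOf(maxOfMins)) = curval
      }
    }
    acc
  }

  def merge(acc1: ArrayBuffer[T], acc2: ArrayBuffer[T]) =
    (acc1 ++ acc2).sorted.take(cutoff)

  def finish(acc: ArrayBuffer[T]) = acc

  // These two lines are what fail: ExpressionEncoder() needs a concrete type,
  // and with an unbound T the compiler infers ExpressionEncoder[Nothing].
  override def bufferEncoder: Encoder[ArrayBuffer[T]] = ExpressionEncoder()
  override def outputEncoder: Encoder[ArrayBuffer[T]] = ExpressionEncoder()
}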

I feel like I am missing something fundamental here. I do not even want the MinN function to be parameterized like that (so that callers have to write MinN[Double]). I want to create something like the built-in min function, which preserves the (Spark) data type of its input.

Edit

I am using the MinN aggregator like this:

  val minVolume = new MinN[Double]("volume").toColumn
  val p = dataframe.agg(minVolume.name("minVolume"))

Tags: scala, apache-spark

Solution


I believe Spark cannot handle that level of abstraction. You can transform your aggregator into something like this:

import scala.collection.mutable

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.expressions.Aggregator

case class MinN[T : Ordering](cutoff: Int = 5)(
    implicit arrEnc: Encoder[mutable.ArrayBuffer[T]])
  extends Aggregator[T, mutable.ArrayBuffer[T], mutable.ArrayBuffer[T]] with Serializable {

  def zero = mutable.ArrayBuffer[T]()

  // The input is now T itself rather than a Row, so no column name is needed.
  // Same logic as the Double version, but comparing through the Ordering context bound.
  def reduce(acc: mutable.ArrayBuffer[T], x: T) = {
    if (acc.length < cutoff) {
      acc.append(x)
    } else {
      val maxOfMins = acc.max
      if (implicitly[Ordering[T]].lt(x, maxOfMins)) {
        acc(acc.indexOf(maxOfMins)) = x
      }
    }
    acc
  }

  def merge(acc1: mutable.ArrayBuffer[T], acc2: mutable.ArrayBuffer[T]) =
    (acc1 ++ acc2).sorted.take(cutoff)

  def finish(acc: mutable.ArrayBuffer[T]) = acc

  // The encoders are no longer built here with ExpressionEncoder(); they are
  // required from the caller as implicit parameters instead.
  override def bufferEncoder: Encoder[mutable.ArrayBuffer[T]] = implicitly
  override def outputEncoder: Encoder[mutable.ArrayBuffer[T]] = implicitly
}
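
A note on where that implicit encoder comes from: with import spark.implicits._ in scope, Spark's SQLImplicits can derive an Encoder for Seq subtypes, and mutable.ArrayBuffer is one, so a concrete instantiation resolves it automatically. A minimal sketch, assuming a live SparkSession named spark:

import scala.collection.mutable
import spark.implicits._ // can derive Encoder[mutable.ArrayBuffer[Double]], since ArrayBuffer is a Seq

// Ordering[Double] and the Encoder constructor parameter are both filled in implicitly here.
val minTwo: MinN[Double] = MinN[Double](2)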

This compiles: the encoders you were missing are now pulled in through the constructor. But using it in an example like the following:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("jander").master("local[1]").getOrCreate()

import spark.implicits._

// Both implicits (Ordering[Double] and Encoder[mutable.ArrayBuffer[Double]]) resolve here.
val custom = MinN[Double](2).toColumn

val df = List(
  ("A", 1.1),
  ("A", 1.2),
  ("A", 1.3)
).toDF("col1", "col2")

// Use the typed column inside an untyped aggregation.
df.groupBy("col1").agg(custom("col2") as "a").show()

throws this exception at runtime:

Exception in thread "main" org.apache.spark.sql.AnalysisException: unresolved operator 'Aggregate [col1#10], [col1#10, minn(MinN(2), None, None, None, newInstance(class org.apache.spark.sql.catalyst.util.GenericArrayData) AS value#1, mapobjects(MapObjects_loopValue0, false, DoubleType, assertnotnull(lambdavariable(MapObjects_loopValue0, false, DoubleType, false)), input[0, array<double>, false], Some(class scala.collection.immutable.List)), newInstance(class org.apache.spark.sql.catalyst.util.GenericArrayData) AS value#0, StructField(value,ArrayType(DoubleType,false),false), true, 0, 0)[col2] AS a#16];;
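
For what it is worth, this kind of Aggregator is normally driven through the typed Dataset API, where its input type is known to the compiler. The following is only an untested sketch of that pattern (it is not from the failing example above, and it assumes the df and MinN definitions from earlier):

// Untested sketch: feed the aggregator through groupByKey/mapValues so that
// its input type (Double) is checked at compile time.
val ds = df.as[(String, Double)]

ds.groupByKey(_._1)                        // key by col1
  .mapValues(_._2)                         // pass only col2 to the aggregator
  .agg(MinN[Double](2).toColumn.name("a")) // TypedColumn keeps its type through .name
  .show()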
