首页 > 解决方案 > 查找每 5 小时间隔的最小值

问题描述

我的

val df = Seq(
  ("1", 1),
  ("1", 1),
  ("1", 2),
  ("1", 4),
  ("1", 5),
  ("1", 6),
  ("1", 8),
  ("1", 12),
  ("1", 12),
  ("1", 13),
  ("1", 14),
  ("1", 15),
  ("1", 16)
).toDF("id", "time")

对于这种情况,第一个间隔从 1 小时开始。因此,最多 6 (1 + 5) 的每一行都是此间隔的一部分。

但是 8 - 1 > 5,所以第二个区间从 8 开始,一直到 13。

然后我看到 14 - 8 > 5,所以第三个开始,依此类推。

想要的结果

+---+----+--------+
|id |time|min_time|
+---+----+--------+
|1  |1   |1       |
|1  |1   |1       |
|1  |2   |1       |
|1  |4   |1       |
|1  |5   |1       |
|1  |6   |1       |
|1  |8   |8       |
|1  |12  |8       |
|1  |12  |8       |
|1  |13  |8       |
|1  |14  |14      |
|1  |15  |14      |
|1  |16  |14      |
+---+----+--------+

我正在尝试使用 min 函数来做到这一点,但不知道如何解释这种情况。

val window = Window.partitionBy($"id").orderBy($"time")
df
.select($"id", $"time")
.withColumn("min_time", when(($"time" - min($"time").over(window)) <= 5, min($"time").over(window)).otherwise($"time"))
.show(false)

我得到了什么

+---+----+--------+
|id |time|min_time|
+---+----+--------+
|1  |1   |1       |
|1  |1   |1       |
|1  |2   |1       |
|1  |4   |1       |
|1  |5   |1       |
|1  |6   |1       |
|1  |8   |8       |
|1  |12  |12      |
|1  |12  |12      |
|1  |13  |13      |
|1  |14  |14      |
|1  |15  |15      |
|1  |16  |16      |
+---+----+--------+

标签: scalaapache-sparkapache-spark-sql

解决方案


您可以使用在window上使用聚合函数的第一个想法。但是,您可以定义自己的Spark 的用户定义聚合函数(UDAF),而不是使用 Spark 已经定义的函数的某种组合。

分析

正如你所想的那样,我们应该在窗口上使用一种 min 函数。在此窗口的行上,我们要实现以下规则:

给定按 排序的行time,如果min_time上一行的time与当前行的 之差大于 5,则当前行min_time应该是当前行time,否则当前行min_time应该是上一行min_time

但是,使用 Spark 提供的聚合函数,我们无法访问前一行的min_time. 它存在一个lag函数,但使用这个函数,我们只能访问先前行的已经存在的值。由于前一行min_time不存在,我们无法访问它。

因此我们必须定义自己的聚合函数

解决方案

定义一个定制的聚合函数

要定义我们的聚合函数,我们需要创建一个扩展Aggregator抽象类的类。下面是完整的实现:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

object MinByInterval extends Aggregator[Integer, Integer, Integer] {

  def zero: Integer = null

  def reduce(buffer: Integer, time: Integer): Integer = {
    if (buffer == null || time - buffer > 5) time else buffer
  }

  def merge(b1: Integer, b2: Integer): Integer = {
    throw new NotImplementedError("should not use as general aggregation")
  }

  def finish(reduction: Integer): Integer = reduction

  def bufferEncoder: Encoder[Integer] = Encoders.INT

  def outputEncoder: Encoder[Integer] = Encoders.INT

}

我们Integer用于输入、缓冲区和输出类型。我们选择Integer它是因为它是可空的Int。我们本可以使用Option[Int],但是 Spark 的文档建议不要在聚合器方法中重新创建对象以解决性能问题,如果我们使用复杂类型(例如Option.

我们在方法中实现分析部分定义的规则reduce

def reduce(buffer: Integer, time: Integer): Integer = {
  if (buffer == null || time - buffer > 5) time else buffer
}

这里time是当前行的time列中buffer的值,也是之前计算的值,所以对应上一行的min_time列。在我们的窗口中,我们按 对行进行排序timetime总是大于buffer。空缓冲区情况仅在处理第一行时发生。

在窗口上使用聚合函数时不使用该方法merge,因此我们不实现它。

finish方法是标识方法,因为我们不需要对我们的聚合值执行最终计算,输出和缓冲区编码器是Encoders.INT

调用用户定义的聚合函数

现在我们可以使用以下代码调用用户定义的聚合函数:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, udaf}

val minTime = udaf(MinByInterval)
val window = Window.partitionBy("id").orderBy("time")
df.withColumn("min_time", minTime(col("time")).over(window))

给定问题中的输入数据框,我们得到:

+---+----+--------+
|id |time|min_time|
+---+----+--------+
|1  |1   |1       |
|1  |1   |1       |
|1  |2   |1       |
|1  |4   |1       |
|1  |5   |1       |
|1  |6   |1       |
|1  |8   |8       |
|1  |12  |8       |
|1  |12  |8       |
|1  |13  |8       |
|1  |14  |14      |
|1  |15  |14      |
|1  |16  |14      |
+---+----+--------+

推荐阅读