首页 > 解决方案 > 如何有效地使用窗口函数根据 N 个先前值来决定接下来的 N 个行

问题描述

您好,我有以下数据。

+----------+----+-------+-----------------------+
|      date|item|avg_val|conditions             |
+----------+----+-------+-----------------------+
|01-10-2020|   x|     10|                      0|
|02-10-2020|   x|     10|                      0|
|03-10-2020|   x|     15|                      1|
|04-10-2020|   x|     15|                      1|
|05-10-2020|   x|      5|                      0|
|06-10-2020|   x|     13|                      1|
|07-10-2020|   x|     10|                      1|
|08-10-2020|   x|     10|                      0|
|09-10-2020|   x|     15|                      1|
|01-10-2020|   y|     10|                      0|
|02-10-2020|   y|     18|                      0|
|03-10-2020|   y|      6|                      1|
|04-10-2020|   y|     10|                      0|
|05-10-2020|   y|     20|                      0|
+----------+----+-------+-----------------------+

我正在尝试创建一个名为 flag level 的新列

  1. 如果标志值为 0,则新列值为 0。
  2. 如果标志为 1,则新列将为 1,接下来的四个 N 行数将为零,即无需检查下一个 N 值。此过程将应用于每个项目,即按项目分区将起作用。

我这里用过 N = 4,

我已经使用了下面的代码,但没有有效的窗口函数是否有任何优化的方法。

DROP TEMPORARY TABLE t2;
CREATE TEMPORARY TABLE t2
SELECT *,
MAX(conditions) OVER (PARTITION BY item ORDER BY item,`date` ROWS 4 PRECEDING ) AS new_row
FROM record
ORDER BY item,`date`;

 

 DROP TEMPORARY TABLE t3;
CREATE TEMPORARY TABLE t3
SELECT *,ROW_NUMBER() OVER (PARTITION BY item,new_row ORDER BY item,`date`) AS e FROM t2;

 


SELECT *,CASE WHEN new_row=1 AND e%5>1 THEN 0 
WHEN new_row=1 AND e%5=1 THEN 1 ELSE 0 END AS flag FROM t3;

输出如下

+----------+----+-------+-----------------------+-----+
|      date|item|avg_val|conditions             |flag |
+----------+----+-------+-----------------------+-----+
|01-10-2020|   x|     10|                      0|    0|
|02-10-2020|   x|     10|                      0|    0|
|03-10-2020|   x|     15|                      1|    1|
|04-10-2020|   x|     15|                      1|    0|
|05-10-2020|   x|      5|                      0|    0|
|06-10-2020|   x|     13|                      1|    0|
|07-10-2020|   x|     10|                      1|    0|
|08-10-2020|   x|     10|                      0|    0|
|09-10-2020|   x|     15|                      1|    1|
|01-10-2020|   y|     10|                      0|    0|
|02-10-2020|   y|     18|                      0|    0|
|03-10-2020|   y|      6|                      1|    1|
|04-10-2020|   y|     10|                      0|    0|
|05-10-2020|   y|     20|                      0|    0|
+----------+----+-------+-----------------------+-----+

但我无法获得输出,我尝试了更多。

标签: mysqlapache-sparkapache-spark-sqlwindowing

解决方案


正如评论中所建议的(@nbk 和 @Akina),您将需要某种迭代器来实现逻辑。使用 SparkSQL 和 Spark 2.4+ 版本,我们可以使用内置函数聚合并设置一个结构数组和一个计数器作为累加器。下面是一个名为的示例数据框和表record(假设conditions列中的值为01):

val df = Seq(
    ("01-10-2020", "x", 10, 0), ("02-10-2020", "x", 10, 0), ("03-10-2020", "x", 15, 1),
    ("04-10-2020", "x", 15, 1), ("05-10-2020", "x", 5, 0), ("06-10-2020", "x", 13, 1),
    ("07-10-2020", "x", 10, 1), ("08-10-2020", "x", 10, 0), ("09-10-2020", "x", 15, 1),
    ("01-10-2020", "y", 10, 0), ("02-10-2020", "y", 18, 0), ("03-10-2020", "y", 6, 1),
    ("04-10-2020", "y", 10, 0), ("05-10-2020", "y", 20, 0)
).toDF("date", "item", "avg_val", "conditions")

df.createOrReplaceTempView("record")

SQL:

spark.sql("""
  SELECT t1.item, m.*
  FROM (
    SELECT item,
      sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta
    FROM record
    GROUP BY item
  ) as t1 LATERAL VIEW OUTER inline(
    aggregate(
      /* expr: set up array `dta` from the 2nd element to the last 
       *       notice that indices for slice function is 1-based, dta[i] is 0-based
       */
      slice(dta,2,size(dta)),
      /* start: set up and initialize `acc` to a struct containing two fields:
       * - dta: an array of structs with a single element dta[0]
       * - counter: number of rows after flag=1, can be from `0` to `N+1`
       */
      (array(dta[0]) as dta, dta[0].conditions as counter),
      /* merge: iterate through the `expr` using x and update two fields of `acc`
       * - dta: append values from x to acc.dta array using concat + array functions
       *        update flag using `IF(acc.counter IN (0,5) and x.conditions = 1, 1, 0)`
       * - counter: increment by 1 if acc.counter is between 1 and 4
       *            , otherwise set value to x.conditions
       */
      (acc, x) -> named_struct(
          'dta', concat(acc.dta, array(named_struct(
              'date', x.date,
              'avg_val', x.avg_val,
              'conditions', x.conditions,
              'flag', IF(acc.counter IN (0,5) and x.conditions = 1, 1, 0)
            ))),
          'counter', IF(acc.counter > 0 and acc.counter < 5, acc.counter+1, x.conditions)
        ),
      /* finish: retrieve acc.dta only and discard acc.counter */
      acc -> acc.dta
    )
  ) m
""").show(50)

结果:

+----+----------+-------+----------+----+
|item|      date|avg_val|conditions|flag|
+----+----------+-------+----------+----+
|   x|01-10-2020|     10|         0|   0|
|   x|02-10-2020|     10|         0|   0|
|   x|03-10-2020|     15|         1|   1|
|   x|04-10-2020|     15|         1|   0|
|   x|05-10-2020|      5|         0|   0|
|   x|06-10-2020|     13|         1|   0|
|   x|07-10-2020|     10|         1|   0|
|   x|08-10-2020|     10|         0|   0|
|   x|09-10-2020|     15|         1|   1|
|   y|01-10-2020|     10|         0|   0|
|   y|02-10-2020|     18|         0|   0|
|   y|03-10-2020|      6|         1|   1|
|   y|04-10-2020|     10|         0|   0|
|   y|05-10-2020|     20|         0|   0|
+----+----------+-------+----------+----+

在哪里:

  1. 用于groupby将同一项目的行收集到一个名为dta列的结构数组中,其中包含 4 个字段:日期avg_val条件标志,并按日期排序
  2. 使用aggregatefunction遍历上面的structs数组,根据countercondition更新flag字段(详情见上面SQL代码注释)
  3. 使用Lateral VIEW内联函数从聚合函数中分解生成的结构数组

笔记:

(1) 建议的 SQL 是针对 N=4 的,这里我们有acc.counter IN (0,5)acc.counter < 5在 SQL 中。对于任何 N,将上述调整为:acc.counter IN (0,N+1)acc.counter < N+1,下图显示了N=2相同样本数据的结果:

+----+----------+-------+----------+----+
|item|      date|avg_val|conditions|flag|
+----+----------+-------+----------+----+
|   x|01-10-2020|     10|         0|   0|
|   x|02-10-2020|     10|         0|   0|
|   x|03-10-2020|     15|         1|   1|
|   x|04-10-2020|     15|         1|   0|
|   x|05-10-2020|      5|         0|   0|
|   x|06-10-2020|     13|         1|   1|
|   x|07-10-2020|     10|         1|   0|
|   x|08-10-2020|     10|         0|   0|
|   x|09-10-2020|     15|         1|   1|
|   y|01-10-2020|     10|         0|   0|
|   y|02-10-2020|     18|         0|   0|
|   y|03-10-2020|      6|         1|   1|
|   y|04-10-2020|     10|         0|   0|
|   y|05-10-2020|     20|         0|   0|
+----+----------+-------+----------+----+

(2) 我们dta[0]用来初始化acc其中包括其字段的值和数据类型。理想情况下,我们应该确保这些字段的数据类型正确,以便正确进行所有计算。例如计算时acc.counter,如果conditions是StringType,acc.counter+1将返回一个StringType,其值为DoubleType

spark.sql("select '2'+1").show()
+---------------------------------------+
|(CAST(2 AS DOUBLE) + CAST(1 AS DOUBLE))|
+---------------------------------------+
|                                    3.0|
+---------------------------------------+

acc.counter IN (0,5)将它们的值与使用或的整数进行比较时,可能会产生浮点错误acc.counter < 5。根据 OP 的反馈,这产生了不正确的结果,没有任何警告/错误消息。

  • 一种解决方法是在设置聚合函数的第二个参数时使用 CAST 指定确切的字段类型,以便在任何类型不匹配时报告错误,见下文:

    CAST((array(dta[0]), dta[0].conditions) as struct<dta:array<struct<date:string,avg_val:string,conditions:int,flag:int>>,counter:int>),
    
  • dta在创建列时强制类型的另一种解决方案,在此示例中,请参见int(conditions) as conditions以下代码:

    SELECT item,
      sort_array(collect_list(struct(date,avg_val,int(conditions) as conditions,conditions as flag))) as dta
    FROM record
    GROUP BY item
    
  • 我们还可以在计算中强制数据类型,例如,见int(acc.counter+1)下文:

    IF(acc.counter > 0 and acc.counter < 5, int(acc.counter+1), x.conditions)      
    

推荐阅读