Window function in Spark-SQL crashes Catalyst

Problem description

Using Spark 2.3.2 and Spark-SQL, the query "b" below fails:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

val dataset = Seq((30, 2.0), (20, 3.0), (19, 20.0)).toDF("age", "size")

// Works: compute the window function into an intermediate column first.
val a0 = dataset.withColumn("rank", rank() over Window.partitionBy('age).orderBy('size))
val a1 = a0.agg(avg('rank))
//a1.show()
//OK

// Same thing in a single expression: crashes the analyzer.
val b = dataset.agg(avg(rank().over(Window.partitionBy('age).orderBy('size))))

AFAIK this is an odd but legitimate SQL query. Doing it through an intermediate column works, but doing it in a single expression crashes Catalyst with a stack overflow:

Exception in thread "main" java.lang.StackOverflowError
at scala.Option.orElse(Option.scala:289)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$find$1.apply(TreeNode.scala:109)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$find$1.apply(TreeNode.scala:109)
at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:124)
at scala.collection.immutable.List.foldLeft(List.scala:84)
at org.apache.spark.sql.catalyst.trees.TreeNode.find(TreeNode.scala:109)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$find$1$$anonfun$apply$1.apply(TreeNode.scala:109)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$find$1$$anonfun$apply$1.apply(TreeNode.scala:109)
at scala.Option.orElse(Option.scala:289)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$find$1.apply(TreeNode.scala:109)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$find$1.apply(TreeNode.scala:109)
at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:124)
at scala.collection.immutable.List.foldLeft(List.scala:84)
at org.apache.spark.sql.catalyst.trees.TreeNode.find(TreeNode.scala:109)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$find$1$$anonfun$apply$1.apply(TreeNode.scala:109)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$find$1$$anonfun$apply$1.apply(TreeNode.scala:109)
at scala.Option.orElse(Option.scala:289)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$find$1.apply(TreeNode.scala:109)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$find$1.apply(TreeNode.scala:109)
at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:124)
at scala.collection.immutable.List.foldLeft(List.scala:84)
at org.apache.spark.sql.catalyst.trees.TreeNode.find(TreeNode.scala:109)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractWindowExpressions$.org$apache$spark$sql$catalyst$analysis$Analyzer$ExtractWindowExpressions$$hasWindowFunction(Analyzer.scala:1757)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractWindowExpressions$$anonfun$71.apply(Analyzer.scala:1781)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractWindowExpressions$$anonfun$71.apply(Analyzer.scala:1781)
at scala.collection.TraversableLike$$anonfun$partition$1.apply(TraversableLike.scala:314)
at scala.collection.TraversableLike$$anonfun$partition$1.apply(TraversableLike.scala:314)
at scala.collection.immutable.List.foreach(List.scala:392)
at scala.collection.TraversableLike$class.partition(TraversableLike.scala:314)
at scala.collection.AbstractTraversable.partition(Traversable.scala:104)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractWindowExpressions$.org$apache$spark$sql$catalyst$analysis$Analyzer$ExtractWindowExpressions$$extract(Analyzer.scala:1781)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractWindowExpressions$$anonfun$apply$28.applyOrElse(Analyzer.scala:1950)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractWindowExpressions$$anonfun$apply$28.applyOrElse(Analyzer.scala:1925)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:266)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272) [...]

Is this a known issue? I'm not 100% sure my query is correct, but at the very least it shouldn't crash Catalyst, since it fails before my query can even be evaluated.

Tags: apache-spark, apache-spark-sql

Solution


avg() expects a named column to aggregate over, but here rank() is nested directly inside it as a raw window expression, so the analyzer has no column name to resolve.

It works if you alias the rank into a column first:

dataset
  .select(rank().over(Window.partitionBy("age").orderBy("size")).as("rank_XX"))
  .agg(avg("rank_XX"))
  .show()

Output:

+------------+                                                                  
|avg(rank_XX)|
+------------+
|         1.0|
+------------+
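For reference, the same two-step shape can also be written in plain Spark SQL with a subquery, which keeps the window function and the aggregate in separate select levels. This is only a sketch: the temp view name `people` is an assumption, and the DataFrame must be registered first.

```scala
// Hypothetical view name "people"; register the example DataFrame first.
dataset.createOrReplaceTempView("people")

// Evaluate the window function in the inner query, aggregate in the outer one --
// the same split that makes the DataFrame version above work.
spark.sql("""
  SELECT avg(rnk)
  FROM (
    SELECT rank() OVER (PARTITION BY age ORDER BY size) AS rnk
    FROM people
  )
""").show()
```

The key point in both variants is the same: the window expression must be materialized as a column before the aggregate sees it.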
