Spark 3.0 - Spark aggregate function gives a different expression than expected

Problem description

/Downloads/spark-3.0.1-bin-hadoop2.7/bin$ ./spark-shell


20/09/23 10:58:45 WARN Utils: Your hostname, byte-nihal resolves to a loopback address: 127.0.1.1; using 192.168.2.103 instead (on interface enp2s0)
20/09/23 10:58:45 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
20/09/23 10:58:49 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Spark context Web UI available at http://192.168.2.103:4040
Spark context available as 'sc' (master = local[*], app id = local-1600838949311).
Spark session available as 'spark'.
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 3.0.1
      /_/
         
Using Scala version 2.12.10 (OpenJDK 64-Bit Server VM, Java 1.8.0_265)
Type in expressions to have them evaluated.
Type :help for more information.

scala> import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions._

scala> println(countDistinct("x"))
count(x)

scala> println(sumDistinct("x"))
sum(DISTINCT x)

scala> println(sum("x"))
sum(x)

scala> println(count("x"))
count(x)

Question:

Is this some kind of bug, or a feature?

Spark 3.0 documentation

Note: countDistinct gives the correct expression -> count(DISTINCT x) in Spark versions < 3.0

Tags: apache-spark, apache-spark-sql

Solution


As @Shaido mentioned in the comments section... I have verified a few things to point out that the latest version of the Spark code behaves differently in toString. (Whether this is a bug or a feature, I am not entirely sure.)

Behavior in Spark code versions < 3.x

import org.apache.spark.sql.functions._

println(countDistinct("x")) ---> gives output as count(DISTINCT x)

If we look specifically at the source code of countDistinct("x"):

  def countDistinct(columnName: String, columnNames: String*): Column =
    countDistinct(Column(columnName), columnNames.map(Column.apply) : _*)
 
  def countDistinct(expr: Column, exprs: Column*): Column = {
    withAggregateFunction(Count.apply((expr +: exprs).map(_.expr)), isDistinct = true)
  }

As you can see in the second overloaded method, the Count aggregate function is used via Count.apply, and isDistinct = true makes it count only distinct values.
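
For context, here is a minimal sketch of how the two overloads are typically called (the DataFrame and column names are only illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().appName("example").master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(("a", 1), ("a", 1), ("b", 2)).toDF("k", "x")

// String overload: a column name plus varargs of further names
df.select(countDistinct("k", "x")).show()            // counts distinct (k, x) pairs -> 2

// Column overload: the same aggregation expressed with Column values
df.select(countDistinct(col("k"), col("x"))).show()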

private def withAggregateFunction(
    func: AggregateFunction,
    isDistinct: Boolean = false): Column = {
    Column(func.toAggregateExpression(isDistinct))
  }
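
To see this step in isolation, here is a minimal sketch that rebuilds the same Column by hand from the catalyst internals (Count and toAggregateExpression are internal APIs, so this is illustrative only and may change between versions):

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.functions.col

// mirror what withAggregateFunction does in Spark < 3.0: wrap the Count
// aggregate function in an AggregateExpression with isDistinct = true
val distinctCount = new Column(Count(Seq(col("x").expr)).toAggregateExpression(isDistinct = true))

println(distinctCount)   // in Spark < 3.0 this prints count(DISTINCT x)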

If you look specifically at the signature of withAggregateFunction, it returns a Column, and Column's toString pretty-prints its expression via toPrettySQL:

 def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql

This calls the .sql method on the AggregateExpression.

As the following code shows, AggregateExpression delegates to the aggregate function's sql method:
override def sql: String = aggregateFunction.sql(isDistinct)

In our case, the AggregateFunction is Count:

def sql(isDistinct: Boolean): String = {
    val distinct = if (isDistinct) "DISTINCT " else ""
    s"$prettyName($distinct${children.map(_.sql).mkString(", ")})"
  }

According to the code above, it should return count(DISTINCT x).
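
To make the interpolation concrete, here is a standalone re-creation of that sql method with stand-in values (plain Scala, no Spark required):

// stand-ins for prettyName, isDistinct and children.map(_.sql) in Count.sql
val prettyName = "count"
val isDistinct = true
val childSql = Seq("x")
val distinct = if (isDistinct) "DISTINCT " else ""
println(s"$prettyName($distinct${childSql.mkString(", ")})")   // prints count(DISTINCT x)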

Now, for Spark versions >= 3.x, I checked the source code and the toString behavior is slightly different.

@scala.annotation.varargs
  def countDistinct(expr: Column, exprs: Column*): Column =
    // For usage like countDistinct("*"), we should let analyzer expand star and
    // resolve function.
    Column(UnresolvedFunction("count", (expr +: exprs).map(_.expr), isDistinct = true))

It now uses UnresolvedFunction instead of withAggregateFunction.
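
You can confirm which expression the returned Column now wraps directly from the shell (a quick check, assuming a Spark >= 3.0 spark-shell):

import org.apache.spark.sql.functions._

val c = countDistinct("x")
println(c.expr.getClass.getSimpleName)   // UnresolvedFunction in Spark >= 3.0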

In UnresolvedFunction, the toString method is very simple, as shown below:

override def toString: String = s"'$name(${children.mkString(", ")})"

This prints count(x)... which is why you get count(x) as the output. Notice that this toString never consults the isDistinct flag, so the DISTINCT keyword is dropped from the printed string even though the flag is still set on the expression.
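
It is worth stressing that only the printed string changed: isDistinct = true is still passed to UnresolvedFunction, so the aggregation itself is unaffected. A quick check (a minimal sketch, run in a 3.0.1 spark-shell where spark and its implicits are available):

import spark.implicits._
import org.apache.spark.sql.functions._

val df = Seq(1, 1, 2, 3, 3).toDF("x")

println(countDistinct("x"))              // prints count(x) -- the misleading toString
df.select(countDistinct("x")).show()     // the query still deduplicates and returns 3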

