首页 > 解决方案 > Spark如何计算字符串列的均值和标准差

问题描述

我有以下数据(仅显示一个片段)

DEST_COUNTRY_NAME   ORIGIN_COUNTRY_NAME count
United States   Romania 15
United States   Croatia 1
United States   Ireland 344
Egypt   United States   15

我用inferSchema选项设置为true然后describe是列来阅读它。它似乎工作正常。

scala> val data = spark.read.option("header", "true").option("inferSchema","true").csv("./data/flight-data/csv/2015-summary.csv")
scala> data.describe().show()
+-------+-----------------+-------------------+------------------+
|summary|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|             count|
+-------+-----------------+-------------------+------------------+
|  count|              256|                256|               256|
|   mean|             null|               null|       1770.765625|
| stddev|             null|               null|23126.516918551915|
|    min|          Algeria|             Angola|                 1|
|    max|           Zambia|            Vietnam|            370002|
+-------+-----------------+-------------------+------------------+

如果我不指定inferSchema,则所有列都被视为字符串。

scala> val dataNoSchema = spark.read.option("header", "true").csv("./data/flight-data/csv/2015-summary.csv")
dataNoSchema: org.apache.spark.sql.DataFrame = [DEST_COUNTRY_NAME: string, ORIGIN_COUNTRY_NAME: string ... 1 more field]

scala> dataNoSchema.printSchema
root
 |-- DEST_COUNTRY_NAME: string (nullable = true)
 |-- ORIGIN_COUNTRY_NAME: string (nullable = true)
 |-- count: string (nullable = true)

问题1)为什么最后一列Spark给出meanstddevcount

scala> dataNoSchema.describe().show();
+-------+-----------------+-------------------+------------------+
|summary|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|             count|
+-------+-----------------+-------------------+------------------+
|  count|              256|                256|               256|
|   mean|             null|               null|       1770.765625|
| stddev|             null|               null|23126.516918551915|
|    min|          Algeria|             Angola|                 1|
|    max|           Zambia|            Vietnam|               986|
+-------+-----------------+-------------------+------------------+

问题 2)如果Spark现在解释countnumeric列,那么为什么max值是 986 而不是 37002(就像在数据 DataFrame 中一样)

标签: apache-spark

解决方案


Spark SQL 渴望符合 SQL 标准,因此使用相同的评估规则,并且如果需要,透明地强制类型以满足表达式(例如,参见PySpark DataFrames 的回答 - 使用不同类型的列之间的比较进行过滤)。

这意味着maxand mean/ stddevcase 根本不等效:

  • maximum 对字符串有意义(使用字典顺序)并且不需要强制

    Seq.empty[String].toDF("count").agg(max("count")).explain
    
    == Physical Plan ==
    SortAggregate(key=[], functions=[max(count#69)])
    +- Exchange SinglePartition
       +- SortAggregate(key=[], functions=[partial_max(count#69)])
          +- LocalTableScan <empty>, [count#69]
    
  • 平均值或标准偏差不是,并且参数被转换为 double

    Seq.empty[String].toDF("count").agg(mean("count")).explain
    
    == Physical Plan ==
    *(2) HashAggregate(keys=[], functions=[avg(cast(count#81 as double))])
    +- Exchange SinglePartition
       +- *(1) HashAggregate(keys=[], functions=[partial_avg(cast(count#81 as double))])
          +- LocalTableScan <empty>, [count#81].
    

推荐阅读