首页 > 解决方案 > Spark Scala - Winsorize 组内的 DataFrame 列

问题描述

我正在为机器学习输入预处理数据,一个目标值列,称之为“价格”有很多异常值,而不是对整个集合的价格进行微调,我想在标有“product_category”的组内进行微调。还有其他功能,product_category 只是一个与价格相关的标签。

有一个 Scala stat 函数效果很好:

df_data.stat.approxQuantile("price", Array(0.01, 0.99), 0.00001)
// res19: Array[Double] = Array(3.13, 318.54)

不幸的是,它不支持计算组内的分位数。也不支持窗口分区。

df_data
    .groupBy("product_category")
    .approxQuantile($"price", Array(0.01, 0.99), 0.00001)

// error: value approxQuantile is not a member of
//   org.apache.spark.sql.RelationalGroupedDataset

为了替换超出该范围的值(即winsorizing),计算火花数据帧组内的 p01 和 p99 的最佳方法是什么?

我的数据集模式可以这样想象,它有超过 20MM 的行,“product_category”有大约 10K 个不同的标签,所以性能也是一个问题。

df_data and a winsorized price column:
+---------+------------------+--------+---------+
|   item  | product_category |  price | pr_winz |
+---------+------------------+--------+---------+
| I000001 |     XX11         |   1.99 |   5.00  |
| I000002 |     XX11         |  59.99 |  59.99  |
| I000003 |     XX11         |1359.00 | 850.00  |
+---------+------------------+--------+---------+
supposing p01 = 5.00, p99 = 850.00 for this product_category 

标签: scalaapache-sparkstatisticsdata-science

解决方案


这是我在努力研究文档后想出的(有两个功能approx_percentilepercentile_approx显然做同样的事情)。

除了作为 spark sql 表达式之外,我无法弄清楚如何实现这一点,不确定为什么分组只在那里有效。我怀疑是因为它是 Hive 的一部分?

Spark DataFrame Winsorizo​​r
  • 在 10 到 100MM 行范围内的 DF 上进行测试
// Winsorize function, groupable by columns list
// low/hi element of [0,1]
// precision: integer in [1, 1E7-ish], in practice use 100 or 1000 for large data, smaller is faster/less accurate
// group_col: comma-separated list of column names
import org.apache.spark.sql._

def grouped_winzo(df: DataFrame, winz_col: String, group_col: String, low: Double, hi: Double, precision: Integer): DataFrame = {
    df.createOrReplaceTempView("df_table")
    
    spark.sql(s"""
    select distinct 
    *
    , percentile_approx($winz_col, $low, $precision) over(partition by $group_col) p_low
    , percentile_approx($winz_col, $hi, $precision) over(partition by $group_col) p_hi
    
    from df_table
    """)
    .withColumn(winz_col + "_winz", expr(s"""
        case when $winz_col <= p_low then p_low
             when $winz_col >= p_hi then p_hi
             else $winz_col end"""))
    .drop(winz_col, "p_low", "p_hi")
    
}

// winsorize the price column of a dataframe at the p01 and p99 
// percentiles, grouped by 'product_category' column.

val df_winsorized = grouped_winzo(
   df_data
   , "price"
   , "product_category"
   , 0.01, 0.99, 1000)

推荐阅读