首页 > 解决方案 > 在分组数据上使用 PySpark Imputer

问题描述

我有一Class列可以是 1、2 或 3,另一列Age缺少一些数据。我想估算每组的Age平均值Class

我想做一些事情:

grouped_data = df.groupBy('Class')
imputer = Imputer(inputCols=['Age'], outputCols=['imputed_Age'])
imputer.fit(grouped_data)

有什么解决方法吗?

谢谢你的时间

标签: pysparkmissing-data

解决方案


使用 Imputer,您可以将数据集过滤到每个Class值,估算平均值,然后将它们连接回来,因为您提前知道这些值可以是什么:

subsets = []
for i in range(1, 4):
    imputer = Imputer(inputCols=['Age'], outputCols=['imputed_Age'])
    subset_df = df.filter(col('Class') == i)
    imputed_subset = imputer.fit(subset_df).transform(subset_df)
    subsets.append(imputed_subset)
# Union them together
# If you only have 3 just do it without a loop
imputed_df = subsets[0].unionByName(subsets[1]).unionByName(subsets[2])

如果您不提前知道这些值是什么,或者它们不容易迭代,您可以分组,将每个组的平均值作为数据帧获取,然后将其合并回原始数据帧。

import pyspark.sql.functions as F
averages = df.groupBy("Class").agg(F.avg("Age").alias("avgAge"))
df_with_avgs = df.join(averages, on="Class")
imputed_df = df_with_avgs.withColumn("imputedAge", F.coalesce("Age", "avgAge"))

推荐阅读