首页 > 解决方案 > 避免在 pyspark 中对多个连接进行洗牌和长期计划

问题描述

我正在使用同一个数据框进行多次连接,我加入的数据框是我原始数据框上 group by 的结果。

    listOfCols = ["a","b","c",....]
    for c in listOfCols:
        means=df.groupby(col(c)).agg(mean(target).alias(f"{c}_mean_encoding"))
        df=df.join(means,c,how="left")

这段代码产生了超过 100000 个任务并且需要很长时间才能完成。我看到在 dag 发生了很多洗牌。我该如何优化这段代码?

标签: apache-sparkpysparkapache-spark-sqlpyspark-dataframes

解决方案


好吧,经过多次尝试和失败,我想出了最快的解决方案。而不是这项工作的 1.5 小时,它运行了 5 分钟......我会把它放在这里,所以如果有人偶然发现它 - 他/她不会像我一样受苦......解决方案是使用 spark sql ,它必须比使用数据框 API 在内部进行更多优化:

df.registerTempTable("df")
for c in listOfCols:
    left_join_string  += f" left join means_{c} on df.{c} = means_{c}.{c}"
    means = df.groupby(F.col(c)).agg(F.mean(target).alias(f"{c}_mean_encoding"))
    means.registerTempTable(f"means_{c}")

df = sqlContext.sql("SELECT * FROM df "+left_join_string)

推荐阅读