首页 > 解决方案 > Scala-Spark:过滤 DataFrame 性能和优化

问题描述

我想要实现的目标非常简单:我想检查所有 ID(uuid)是否经历了某种“状态”(行为状态)。如果有,则将与该 ID 关联的所有记录返回给我。例如,如果以下 ID 之一的状态为“三”,我希望保留与该 ID 关联的所有记录。到目前为止,我可以通过以下两种方式实现这一目标:

// first method
val dfList = df.filter($"status" === "three").select($"id").distinct.map(_.getString(0)).collect.toList
val dfTransformedOne = df.filter($"id".isin(dfList:_*))

// second method
val dfIds = df.filter($"status" === "three").select($"id").distinct
val dfTransformedTwo = df.join(broadcast(dfIds), Seq("id"))

上述两种方法适用于我正在使用的示例数据,但是当我开始增加要处理的数据量时遇到了一些性能问题,因为我可能需要数百万到数亿个 ID过滤。是否有更有效的方法来执行上述操作,或者只是增加我正在使用的硬件的情况?

以下是示例数据和预期输出。

val df = Seq(
  ("1234", "one"), 
  ("1234", "two"), 
  ("1234", "three"), 
  ("234", "one"), 
  ("234", "one"), 
  ("234", "two")
  ).toDF("id", "status")

df.show
+----+------+
|  id|status|
+----+------+
|1234|   one|
|1234|   two|
|1234| three|
| 234|   one|
| 234|   one|
| 234|   two|
+----+------+

dfTransformed.show()
+----+------+
|  id|status|
+----+------+
|1234|   one|
|1234|   two|
|1234| three|
+----+------+

标签: apache-sparkapache-spark-sql

解决方案


在过滤之前进行分组和聚合将引入一个 shuffle,同时消除了将大列表收集到驱动程序的需要。是否更快取决于您的数据分布、集群大小和网络连接。不过,这可能值得一试:

val df = Seq(
  ("1234", "one"), 
  ("1234", "two"), 
  ("1234", "three"), 
  ("234", "one"), 
  ("234", "one"), 
  ("234", "two")
  ).toDF("id", "status")

df.groupBy("id")
  .agg(collect_list("status").as("statuses"))
  .filter(array_contains($"statuses", "three"))
  .withColumn("status", explode($"statuses"))
  .select("id", "status")
  .show(false)

给出预期的结果:

+----+------+
|id  |status|
+----+------+
|1234|one   |
|1234|two   |
|1234|three |
+----+------+

推荐阅读