Spark SortMergeJoin will not change to ShuffleHashJoin

Problem description

I am trying to force Spark to use ShuffleHashJoin by disabling BroadcastHashJoin and SortMergeJoin, but Spark always uses SortMergeJoin.

I am using Spark version 2.4.3.

    import org.apache.log4j.{Level, Logger}
    import org.apache.spark.sql.SparkSession

    object ShuffleHashJoin {

      def main(args: Array[String]): Unit = {
        Logger.getLogger("org").setLevel(Level.ERROR)

        val spark = SparkSession.builder()
          .appName("ShuffleHashJoin")
          .master("local[*]")
          .getOrCreate()

        // Disable auto broadcasting of tables and SortMergeJoin.
        spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 0)
        spark.conf.set("spark.sql.join.preferSortMergeJoin", false)

        import spark.implicits._
        val dataset = Seq(
          (0, "playing"),
          (1, "with"),
          (2, "ShuffledHashJoinExec")).toDF("id", "token")

        dataset.join(dataset, Seq("id"), "inner").foreach(_ => ())

        // Infinite loop to keep the program running, so the Spark UI
        // on port 4040 can be inspected.
        while (true) {}
      }
    }

Tags: apache-spark, apache-spark-sql

Solution


In addition to setting spark.sql.join.preferSortMergeJoin to false, Spark also verifies the following conditions (see the source code):

  1. A single partition should be small enough to build a hash table

canBuildLocalHashMap(right) || canBuildLocalHashMap(left)
  |-> plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions

You have programmatically set spark.sql.autoBroadcastJoinThreshold to 0, so this condition always evaluates to false.
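The effect of the threshold can be sketched as plain Scala. This is a minimal, illustrative model of the size check, not Spark's actual internal API; the function name and parameters mirror the condition quoted above:

```scala
// Illustrative sketch of the canBuildLocalHashMap size check:
// a single partition's estimated size must stay below
// autoBroadcastJoinThreshold * numShufflePartitions.
def canBuildLocalHashMap(planSizeInBytes: BigInt,
                         autoBroadcastJoinThreshold: Long,
                         numShufflePartitions: Int): Boolean =
  planSizeInBytes < BigInt(autoBroadcastJoinThreshold) * numShufflePartitions

// With the threshold set to 0, the right-hand side is 0, so the check
// can never pass, no matter how small the plan is:
canBuildLocalHashMap(100, autoBroadcastJoinThreshold = 0, numShufflePartitions = 200)  // false
// With a small non-zero threshold, a tiny plan passes:
canBuildLocalHashMap(100, autoBroadcastJoinThreshold = 2, numShufflePartitions = 200)  // true
```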

  2. One side of the join is much smaller than the other

    Building a hash map is more expensive than sorting, so the hash map should only be built on a table that is much smaller than the other. Since there are no statistics on row counts, the byte size is used here as an estimate.

muchSmaller(right, left) || muchSmaller(left, right) 
 |-> a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes
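This "much smaller" heuristic can also be sketched in plain Scala. Again, the names here are illustrative placeholders, not Spark's internals; the rule is the one quoted above, using estimated sizes in bytes:

```scala
// Illustrative sketch of the muchSmaller heuristic:
// side a is "much smaller" than side b when a * 3 <= b.
def muchSmaller(aSizeInBytes: BigInt, bSizeInBytes: BigInt): Boolean =
  aSizeInBytes * 3 <= bSizeInBytes

// ShuffledHashJoinExec is only considered when one side passes the check:
val leftSize  = BigInt(100)  // assumed estimated bytes of the left plan
val rightSize = BigInt(400)  // assumed estimated bytes of the right plan
muchSmaller(leftSize, rightSize) || muchSmaller(rightSize, leftSize)  // true: left is at least 3x smaller
```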

In your example, we need to do a couple of things to make this work:

  1. Change the auto broadcast threshold to some small value: spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 2)

  2. Make one side of the join at least 3x larger than the other

And a working example:

    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 2)
    spark.conf.set("spark.sql.join.preferSortMergeJoin", false)

    import spark.implicits._
    val dataset = Seq(
      (0, "playing"),
      (1, "with"),
      (2, "ShuffledHashJoinExec")).toDF("id", "token")

    val right = Seq(
      (0, "asdfghjklzxcvb"),
      (1, "asdfghjklzxcvb"),
      (2, "asdfghjklzxcvb"),
      (3, "asdfghjklzxcvb"),
      (4, "asdfghjklzxcvb"),
      (5, "asdfghjklzxcvb"),
      (6, "asdfghjklzxcvb"),
      (7, "asdfghjklzxcvb"),
      (8, "asdfghjklzxcvb"),
      (9, "asdfghjklzxcvb"))
      .toDF("id", "token")

    val joined = dataset.join(right, Seq("id"), "inner")
    joined.explain(true)

*(1) Project [id#5, token#6, token#15]
+- ShuffledHashJoin [id#5], [id#14], Inner, BuildLeft
   :- Exchange hashpartitioning(id#5, 200)
   :  +- LocalTableScan [id#5, token#6]
   +- Exchange hashpartitioning(id#14, 200)
      +- LocalTableScan [id#14, token#15]
