Optimizing the shuffle in a Spark full outer join over Datasets

Problem description

I am using Spark 2.1 with the DataFrame API to do the following:

import java.security.MessageDigest

import org.apache.spark.sql.{Dataset, Encoders}
import org.apache.spark.sql.functions._
import spark.implicits._

case class C(id_1: String, id_2: String, a: Option[Int], b: String)

val schema = Encoders.product[C]

val data1 = Seq(
  ("d1_r0", "d1_t0", 1, "yyy"),
  ("d1_r1", "d1_t1", 2, "xxx"),
  ("d2_r2", "d2_t2", 3, "ppp"),
  ("d1_r3", "d1_t3", 4, "iii")
)

val df1 = data1.toDF("id_1", "id_2", "a", "b")
val ds1: Dataset[C] = df1.as(schema)

val data2 = Seq(
  ("d2_r0", "d2_t0", 1, "lll"),
  ("d1_r1", "d1_t1", 2, "mmm"),
  ("d2_r2", "d2_t2", 3, "ppp"),
  ("d2_r3", "d2_t3", 4, "nnn")
)

val df2 = data2.toDF("id_1", "id_2", "a", "b")
val ds2: Dataset[C] = df2.as(schema)

def getMD5Hash(x: C): String = {
  // x.a is an Option, so it renders as e.g. "Some(1)" in the concatenation.
  val str = x.id_1 + x.id_2 + x.a + x.b
  // Note: a fresh MessageDigest is instantiated for every row; at millions
  // of rows this per-row allocation is itself a measurable cost.
  val msgDigest: MessageDigest = MessageDigest.getInstance("MD5")
  msgDigest
    .digest(str.getBytes())
    .map(0xff & _)
    .map("%02x".format(_))
    .mkString
}
// Update: keep the old row's ids and a, take the new row's b.
def u(newV: C, oldV: C): Seq[C] =
  Seq(C(oldV.id_1, oldV.id_2, oldV.a, newV.b))
// Decide, per joined pair, what to emit for each row of the full outer join.
def uOrI(b: String)(row: (C, C)): Seq[C] = {
  row match {
    case (newV, null) => Seq(newV)                                // only in ds1: keep as-is
    case (null, oldV) => Seq(C(oldV.id_1, oldV.id_2, oldV.a, b))  // only in ds2: overwrite b
    case (newV, oldV) =>                                          // in both: compare hashes
      if (getMD5Hash(newV) == getMD5Hash(oldV)) Seq(oldV)         // unchanged: keep old row
      else u(newV, oldV)                                          // changed: update
  }
}

val df3 = ds1
  .joinWith(
    ds2,
    $"_1.id_1" === $"_2.id_1" && $"_1.id_2" === $"_2.id_2",
    "full_outer"
  )
  .flatMap(uOrI("jjjjjjjj"))
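Tracing uOrI by hand over the toy inputs, df3 should contain six rows (row order after a shuffle is not deterministic): the two ds1-only rows kept as-is, the two ds2-only rows with b replaced by the marker, the (d2_r2, d2_t2) pair unchanged (equal hashes), and the (d1_r1, d1_t1) pair taking the new b:

df3.show()
// +-----+-----+---+--------+
// | id_1| id_2|  a|       b|
// +-----+-----+---+--------+
// |d1_r0|d1_t0|  1|     yyy|
// |d1_r1|d1_t1|  2|     xxx|
// |d1_r3|d1_t3|  4|     iii|
// |d2_r0|d2_t0|  1|jjjjjjjj|
// |d2_r2|d2_t2|  3|     ppp|
// |d2_r3|d2_t3|  4|jjjjjjjj|
// +-----+-----+---+--------+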

The program runs and produces what I expect, but on the real datasets (df1 and df2 both over 1 million rows) the solution is very slow: it takes about 30 minutes to finish on a YARN cluster with 10 nodes (16 CPUs and 128 GB of RAM each).
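As a first check on where that time goes, the physical plan should show the two expensive pieces: the full outer join on two equality keys compiles to a sort-merge join (both inputs shuffled by the join keys), and the typed flatMap forces every row out of Tungsten format into a C object and back:

df3.explain(true)
// Expect a SortMergeJoin (FullOuter) fed by two Exchange (shuffle) nodes,
// wrapped in the DeserializeToObject / MapPartitions / SerializeFromObject
// operators that Catalyst inserts around the Scala flatMap.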

Is there another solution or idea that would optimize the shuffle and the running time?
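One direction, sketched here under two assumptions that should be stated explicitly: (1) id_1 is never null in the source data, so a null id_1 after the full outer join reliably marks the missing side; (2) reading uOrI closely, the MD5 comparison is redundant, because when both sides match the output is always (old ids, old a, new b) whether the hashes agree or not (when they agree, oldV is that same row, concatenation corner cases aside). Under those assumptions the whole merge can be written with untyped column expressions, which keeps rows in Tungsten format, drops both the per-row MD5 work and the object round-trip, and lets Catalyst plan the job end to end. A minimal, untested sketch:

val n = ds1.toDF("n_id_1", "n_id_2", "n_a", "n_b")
val o = ds2.toDF("o_id_1", "o_id_2", "o_a", "o_b")

val merged = n
  .join(o, $"n_id_1" === $"o_id_1" && $"n_id_2" === $"o_id_2", "full_outer")
  .select(
    coalesce($"o_id_1", $"n_id_1").as("id_1"),
    coalesce($"o_id_2", $"n_id_2").as("id_2"),
    when($"o_id_1".isNull, $"n_a").otherwise($"o_a").as("a"),          // old a wins when both sides exist
    when($"n_id_1".isNull, lit("jjjjjjjj")).otherwise($"n_b").as("b")  // ds2-only rows get the marker
  )
  .as[C]

The shuffle itself still happens (a full outer join of two unsorted inputs needs one), but the per-row Scala work disappears from the plan.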

Tags: scala, apache-spark, join

Solution
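A minimal sketch of one way to attack the shuffle itself, assuming the two inputs are written once and joined (or re-joined) later: persist both sides bucketed and sorted by the join keys. bucketBy, sortBy, and saveAsTable are standard DataFrameWriter calls; whether the planner can then drop the Exchange nodes depends on both tables using the same bucket count on the same columns. The bucket count (200 here) and the table names are placeholders to tune:

// Hypothetical table names; 200 buckets is an assumption to tune.
ds1.write
  .bucketBy(200, "id_1", "id_2")
  .sortBy("id_1", "id_2")
  .saveAsTable("ds1_bucketed")

ds2.write
  .bucketBy(200, "id_1", "id_2")
  .sortBy("id_1", "id_2")
  .saveAsTable("ds2_bucketed")

// Joining the bucketed tables on the bucket columns can then run the
// sort-merge join without re-shuffling either side.
val b1 = spark.table("ds1_bucketed")
val b2 = spark.table("ds2_bucketed")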

