Is it necessary to add SaveMode Delete, Update and Upsert for the Spark JDBC source?

Problem description

Do you think it is necessary to add SaveMode variants for Delete, Update and Upsert? For example:

Reference code: JdbcRelationProvider.scala

I have analyzed its saveTable code (JdbcUtils.scala), and I think it would be easy to extend the current insert implementation with delete, update and merge statements, for example:

  def getDeleteStatement(table: String, rddSchema: StructType, dialect: JdbcDialect): String = {
    val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name) + "=?").mkString(" AND ")

    s"DELETE FROM ${table.toUpperCase} WHERE $columns"
  }

  def getUpdateStatement(table: String, rddSchema: StructType, priKeys: Seq[String], dialect: JdbcDialect): String = {
    val fullCols = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name))
    val priCols = priKeys.map(dialect.quoteIdentifier(_))
    val columns = (fullCols diff priCols).map(_ + "=?").mkString(",")
    val cnditns = priCols.map(_ + "=?").mkString(" AND ")

    s"UPDATE ${table.toUpperCase} SET $columns WHERE $cnditns"
  }

  def getMergeStatement(table: String, rddSchema: StructType, priKeys: Seq[String], dialect: JdbcDialect): String = {
    val fullCols = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name))
    val priCols = priKeys.map(dialect.quoteIdentifier(_))
    val nrmCols = fullCols diff priCols

    val fullPart = fullCols.map(c => s"${dialect.quoteIdentifier("SRC")}.$c").mkString(",")
    val priPart = priCols.map(c => s"${dialect.quoteIdentifier("TGT")}.$c=${dialect.quoteIdentifier("SRC")}.$c").mkString(" AND ")
    val nrmPart = nrmCols.map(c => s"$c=${dialect.quoteIdentifier("SRC")}.$c").mkString(",")

    val columns = fullCols.mkString(",")
    val placeholders = fullCols.map(_ => "?").mkString(",")

    s"MERGE INTO ${table.toUpperCase} AS ${dialect.quoteIdentifier("TGT")} " +
      s"USING TABLE(VALUES($placeholders)) " +
      s"AS ${dialect.quoteIdentifier("SRC")}($columns) " +
      s"ON $priPart " +
      s"WHEN NOT MATCHED THEN INSERT ($columns) VALUES ($fullPart) " +
      s"WHEN MATCHED THEN UPDATE SET $nrmPart"
  }
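As a quick sanity check, the update builder above can be exercised outside Spark. The sketch below re-implements its logic with a stand-in for `dialect.quoteIdentifier` (double quotes, which is what Spark's default `JdbcDialect` produces); `StatementSketch` and its members are illustrative names, not part of Spark.

```scala
// Sketch: what the proposed getUpdateStatement would emit, with a stand-in
// for dialect.quoteIdentifier (Spark's default dialect wraps identifiers
// in double quotes).
object StatementSketch {
  def quote(id: String): String = s""""$id""""

  // Mirrors the proposed getUpdateStatement, minus the Spark-internal types.
  def updateStatement(table: String, cols: Seq[String], priKeys: Seq[String]): String = {
    val fullCols  = cols.map(quote)
    val priCols   = priKeys.map(quote)
    val setPart   = (fullCols diff priCols).map(_ + "=?").mkString(",")
    val wherePart = priCols.map(_ + "=?").mkString(" AND ")
    s"UPDATE ${table.toUpperCase} SET $setPart WHERE $wherePart"
  }
}
```

Note the placeholder order this produces: the SET columns come first and the key columns last, so whoever binds row values to the `?` markers has to follow that order rather than the schema's field order.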

The only extra requirement is that we supply primary keys for them. They can all call the same savePartition function; we just replace insertStmt with a runningStmt that covers any of insert/delete/update/merge:

  def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      insertStmt: String,
      batchSize: Int,
      dialect: JdbcDialect,
      isolationLevel: Int,
      options: JDBCOptions): Iterator[Byte] 
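For context, here is a minimal sketch of the dispatch this proposal implies: given an extended mode, pick which generated statement becomes the `runningStmt` passed where `insertStmt` goes today. The mode names below are hypothetical stand-ins, not real `SaveMode` values, and the statement strings are taken by name so only the selected builder runs.

```scala
// Sketch: dispatch from a (hypothetical) extended save mode to one of the
// proposed statement builders. None of these modes exist in Spark's SaveMode.
object StmtDispatch {
  sealed trait ExtendedMode
  case object InsertMode extends ExtendedMode
  case object DeleteMode extends ExtendedMode
  case object UpdateMode extends ExtendedMode
  case object UpsertMode extends ExtendedMode

  // Each parameter stands in for a call to the corresponding builder
  // (getInsertStatement / getDeleteStatement / getUpdateStatement /
  // getMergeStatement); by-name so only the chosen one is evaluated.
  def runningStmt(mode: ExtendedMode,
                  insert: => String, delete: => String,
                  update: => String, merge: => String): String = mode match {
    case InsertMode => insert
    case DeleteMode => delete
    case UpdateMode => update
    case UpsertMode => merge
  }
}
```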

Tags: scala, apache-spark, spark-jdbc
