Advice on refactoring Scala - can I eliminate the vars used in my foreach loops?

Problem description

Looking for some advice on how to refactor some Scala code to make it more elegant and more idiomatic Scala.

I have a function

def joinDataFramesOnColumns(joinColumns: Seq[String]) : org.apache.spark.sql.DataFrame

that operates on a Seq[org.apache.spark.sql.DataFrame] by joining the frames together on the joinColumns. Here is the function definition:

import org.apache.spark.sql.{DataFrame, SparkSession}

implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]){
  def joinDataFramesOnColumns(joinColumns: Seq[String]) : DataFrame = {
    val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
    val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
    if (nonEmptyDataFrames.isEmpty){
      emptyDataFrame
    }
    else {
      if (joinColumns.isEmpty) {
        return nonEmptyDataFrames.reduce(_.crossJoin(_))
      }
      nonEmptyDataFrames.reduce(_.join(_, joinColumns))
    }
  }
}
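
With the implicit class in scope, the join is invoked directly on a Seq of DataFrames; a minimal usage sketch (df1 and df2 are hypothetical frames that share a "basket" column):

// Hypothetical usage: assumes the implicit class above is in scope and
// that df1 and df2 both contain a "basket" column to join on.
val joined: DataFrame = Seq(df1, df2).joinDataFramesOnColumns(Seq("basket"))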

I have some unit tests which all pass:

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.WordSpec

// `spark`, `udfDateSubWeeks` and `dayPriorToAsAt` are assumed to be defined
// elsewhere in the test fixture.
class FeatureGeneratorDataFrameExtensionsTest extends WordSpec {
  val fruitValues = Seq(
    Row(0, "BasketA", "Bananas", "Jack"),
    Row(2, "BasketB", "Oranges", "Jack"),
    Row(2, "BasketC", "Oranges", "Jill"),
    Row(3, "BasketD", "Oranges", "Jack"),
    Row(4, "BasketE", "Oranges", "Jack"),
    Row(4, "BasketE", "Apples", "Jack"),
    Row(4, "BasketF", "Bananas", "Jill")
  )
  val schema = List(
    StructField("weeksPrior", IntegerType, true),
    StructField("basket", StringType, true),
    StructField("Product", StringType, true),
    StructField("Customer", StringType, true)
  )
  val fruitDf = spark.createDataFrame(
    spark.sparkContext.parallelize(fruitValues),
    StructType(schema)
  ).withColumn("Date", udfDateSubWeeks(lit(dayPriorToAsAt), col("weeksPrior")))

  "FeatureGenerator.SequenceOfDataFrames" should {
    "join multiple dataframes on a specified set of columns" in {
      val sequenceOfDataFrames = Seq[DataFrame](
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior1"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior2"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior3"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior4"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior5")
      )
      val joinedDataFrames = sequenceOfDataFrames.joinDataFramesOnColumns(Seq("basket", "Product", "Customer", "Date"))
      assert(joinedDataFrames.columns.length === 9)
      assert(joinedDataFrames.columns.contains("basket"))
      assert(joinedDataFrames.columns.contains("Product"))
      assert(joinedDataFrames.columns.contains("Customer"))
      assert(joinedDataFrames.columns.contains("Date"))
      assert(joinedDataFrames.columns.contains("weeksPrior1"))
      assert(joinedDataFrames.columns.contains("weeksPrior2"))
      assert(joinedDataFrames.columns.contains("weeksPrior3"))
      assert(joinedDataFrames.columns.contains("weeksPrior4"))
      assert(joinedDataFrames.columns.contains("weeksPrior5"))
    }
    "when passed a list of one dataframe return that same dataframe" in {
      val sequenceOfDataFrames = Seq[DataFrame](fruitDf)
      val joinedDataFrame = sequenceOfDataFrames.joinDataFramesOnColumns(Seq("basket", "Product"))
      assert(joinedDataFrame.columns.sorted === fruitDf.columns.sorted)
      assert(joinedDataFrame.count === fruitDf.count)
    }
    "when passed an empty list of dataframes return an empty dataframe" in {
      val joinedDataFrame = Seq[DataFrame]().joinDataFramesOnColumns(Seq("basket"))
      assert(joinedDataFrame === spark.emptyDataFrame)
    }
    "when passed an empty list of joinColumns return the dataframes crossjoined" in {
      val sequenceOfDataFrames = Seq[DataFrame](fruitDf,fruitDf, fruitDf)
      val joinedDataFrame = sequenceOfDataFrames.joinDataFramesOnColumns(Seq[String]())
      assert(joinedDataFrame.count === scala.math.pow(fruitDf.count, sequenceOfDataFrames.size))
      assert(joinedDataFrame.columns.size === fruitDf.columns.size * sequenceOfDataFrames.size)
    }
  }
}

This was all working well until it started erroring due to this Spark bug: https://issues.apache.org/jira/browse/SPARK-25150, which can cause errors in certain circumstances when the join columns have the same name as each other.
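
For illustration, the shape of join that can trigger the issue looks like the sketch below, built from the fruitDf test fixture above (a hypothetical reproduction based on my reading of the ticket, not code taken from it):

// Both sides derive from the same source DataFrame and the join columns
// keep identical names on each side - the situation SPARK-25150 describes.
val left  = fruitDf.withColumnRenamed("weeksPrior", "weeksPrior1")
val right = fruitDf.withColumnRenamed("weeksPrior", "weeksPrior2")
val joined = left.join(right, Seq("basket", "Product", "Customer", "Date"))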

The workaround is to alias the join columns to something else, so I rewrote the function like this: it aliases the join columns, performs the join, then renames them back:

implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]){
  def joinDataFramesOnColumns(joinColumns: Seq[String]) : DataFrame = {
    val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
    val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
    if (nonEmptyDataFrames.isEmpty){
      emptyDataFrame
    }
    else {
      if (joinColumns.isEmpty) {
        return nonEmptyDataFrames.reduce(_.crossJoin(_))
      }

      /*
      The horrible, gnarly, inelegant code below would ideally exist simply as:

      nonEmptyDataFrames.reduce(_.join(_, joinColumns))

      however that will fail in certain specific circumstances due to a bug in Spark,
      see https://issues.apache.org/jira/browse/SPARK-25150 for details
       */
      val aliasSuffix = "_aliased"
      val aliasedJoinColumns = joinColumns.map(joinColumn => joinColumn + aliasSuffix)
      // rename every join column in every DataFrame to its aliased name
      var aliasedNonEmptyDataFrames: Seq[DataFrame] = Seq()
      nonEmptyDataFrames.foreach(
        nonEmptyDataFrame => {
          var tempNonEmptyDataFrame = nonEmptyDataFrame
          joinColumns.foreach(
            joinColumn => {
              tempNonEmptyDataFrame = tempNonEmptyDataFrame.withColumnRenamed(joinColumn, joinColumn + aliasSuffix)
            }
          )
          aliasedNonEmptyDataFrames = aliasedNonEmptyDataFrames :+ tempNonEmptyDataFrame
        }
      )
      // join on the aliased columns, then rename them back to their original names
      var joinedAliasedNonEmptyDataFrames = aliasedNonEmptyDataFrames.reduce(_.join(_, aliasedJoinColumns))
      joinColumns.foreach(
        joinColumn => joinedAliasedNonEmptyDataFrames = joinedAliasedNonEmptyDataFrames.withColumnRenamed(
          joinColumn + aliasSuffix, joinColumn
        )
      )
      joinedAliasedNonEmptyDataFrames
    }
  }
}

The tests still pass, so I'm fairly happy with it, but I'm looking at those vars and the loops that assign the result back to the var on every iteration... and finding them rather inelegant, rather ugly, especially compared to the original version of the function. I feel there must be a way to write this so that I don't have to use vars, but after some trial and error this is the best I could do.

Can anyone suggest a more elegant solution? As a novice Scala developer it would really help me get more familiar with the idiomatic way of solving problems like this.

Any constructive comments on the rest of the code (e.g. the tests) would also be welcome.

Tags: scala, apache-spark

Solution


Thanks go to @Duelist, whose suggestion to use foldLeft() led me to "How does foldLeft in Scala work on DataFrame?", which in turn led me to adjust my code like this to eliminate the vars:

  implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]){
    def joinDataFramesOnColumns(joinColumns: Seq[String]) : DataFrame = {
      val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
      val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
      if (nonEmptyDataFrames.isEmpty){
        emptyDataFrame
      }
      else {
        if (joinColumns.isEmpty) {
          return nonEmptyDataFrames.reduce(_.crossJoin(_))
        }

        /*
        The code below would ideally exist simply as:

        nonEmptyDataFrames.reduce(_.join(_, joinColumns))

        however that will fail in certain specific circumstances due to a bug in spark,
        see https://issues.apache.org/jira/browse/SPARK-25150 for details

        hence this code aliases the joinColumns, performs the join, then renames the 
        aliased columns back to their original name
         */
        val aliasSuffix = "_aliased"
        val aliasedJoinColumns = joinColumns.map(joinColumn => joinColumn+aliasSuffix)
        val joinedAliasedNonEmptyDataFrames = nonEmptyDataFrames.foldLeft(Seq[DataFrame]()){
          (tempDf, nonEmptyDataFrame) => tempDf :+ joinColumns.foldLeft(nonEmptyDataFrame){
            (tempDf2, joinColumn) => tempDf2.withColumnRenamed(joinColumn, joinColumn+aliasSuffix)
          }
        }.reduce(_.join(_, aliasedJoinColumns))
        joinColumns.foldLeft(joinedAliasedNonEmptyDataFrames){
          (tempDf, joinColumn) => tempDf.withColumnRenamed(joinColumn+aliasSuffix, joinColumn)
        }
      }
    }
  }
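
As a side note on how the foldLeft calls work: foldLeft threads an accumulator through a collection, which is exactly the var-plus-foreach pattern from the earlier version made immutable. A tiny stand-alone illustration in plain Scala (no Spark required):

// foldLeft replaces the mutable pattern
//   var acc = start; items.foreach(x => acc = f(acc, x))
// with a single immutable expression: items.foldLeft(start)(f)
val cols = Seq("basket", "Product", "Customer")
val aliased = cols.foldLeft(Seq.empty[String])((acc, c) => acc :+ (c + "_aliased"))
// aliased == Seq("basket_aliased", "Product_aliased", "Customer_aliased")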

I could have gone one step further and eliminated the val joinedAliasedNonEmptyDataFrames by combining the two statements into one, but I prefer the clarity that the interim val brings.
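
For illustration, that combined form might look something like the sketch below (also swapping the outer foldLeft-into-a-Seq for an equivalent map, since that accumulator was only building a Seq):

// Sketch of the single-expression version: alias, join, then rename back,
// with no interim val. Equivalent to the two statements above.
joinColumns.foldLeft(
  nonEmptyDataFrames
    .map(df => joinColumns.foldLeft(df)((d, c) => d.withColumnRenamed(c, c + aliasSuffix)))
    .reduce(_.join(_, aliasedJoinColumns))
) { (df, c) => df.withColumnRenamed(c + aliasSuffix, c) }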

