scala - Advice on refactoring Scala code - can I eliminate the vars used in foreach loops?
Problem description

Looking for some advice on how to refactor some Scala code to make it more elegant and more idiomatic Scala.

I have a function

def joinDataFramesOnColumns(joinColumns: Seq[String]) : org.apache.spark.sql.DataFrame

that operates on a Seq[org.apache.spark.sql.DataFrame] by joining the DataFrames together on joinColumns. Here is the function definition:
import org.apache.spark.sql.{DataFrame, SparkSession}

implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]) {
  def joinDataFramesOnColumns(joinColumns: Seq[String]): DataFrame = {
    val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
    val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
    if (nonEmptyDataFrames.isEmpty) {
      emptyDataFrame
    }
    else {
      if (joinColumns.isEmpty) {
        return nonEmptyDataFrames.reduce(_.crossJoin(_))
      }
      nonEmptyDataFrames.reduce(_.join(_, joinColumns))
    }
  }
}
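For readers new to the implicit-class extension pattern used above, here is a minimal, Spark-free sketch of the same idea (all names hypothetical): an implicit class pins an extension method onto Seq without modifying Seq itself.

```scala
object ImplicitClassDemo extends App {
  // Hypothetical analogue of SequenceOfDataFrames: wraps Seq[String]
  // and exposes an extra method on it.
  implicit class SequenceOfStrings(strings: Seq[String]) {
    def joinNonEmptyOnSeparator(separator: String): String =
      strings.filter(_.nonEmpty).mkString(separator)
  }

  // The method is available on any Seq[String] wherever the
  // implicit class is in scope.
  println(Seq("a", "", "b").joinNonEmptyOnSeparator("-")) // prints a-b
}
```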
I have some unit tests, all of which pass:
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.WordSpec

class FeatureGeneratorDataFrameExtensionsTest extends WordSpec {
  val fruitValues = Seq(
    Row(0, "BasketA", "Bananas", "Jack"),
    Row(2, "BasketB", "Oranges", "Jack"),
    Row(2, "BasketC", "Oranges", "Jill"),
    Row(3, "BasketD", "Oranges", "Jack"),
    Row(4, "BasketE", "Oranges", "Jack"),
    Row(4, "BasketE", "Apples", "Jack"),
    Row(4, "BasketF", "Bananas", "Jill")
  )
  val schema = List(
    StructField("weeksPrior", IntegerType, true),
    StructField("basket", StringType, true),
    StructField("Product", StringType, true),
    StructField("Customer", StringType, true)
  )
  val fruitDf = spark.createDataFrame(
    spark.sparkContext.parallelize(fruitValues),
    StructType(schema)
  ).withColumn("Date", udfDateSubWeeks(lit(dayPriorToAsAt), col("weeksPrior")))

  "FeatureGenerator.SequenceOfDataFrames" should {
    "join multiple dataframes on a specified set of columns" in {
      val sequenceOfDataFrames = Seq[DataFrame](
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior1"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior2"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior3"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior4"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior5")
      )
      val joinedDataFrames = sequenceOfDataFrames.joinDataFramesOnColumns(Seq("basket", "Product", "Customer", "Date"))
      assert(joinedDataFrames.columns.length === 9)
      assert(joinedDataFrames.columns.contains("basket"))
      assert(joinedDataFrames.columns.contains("Product"))
      assert(joinedDataFrames.columns.contains("Customer"))
      assert(joinedDataFrames.columns.contains("Date"))
      assert(joinedDataFrames.columns.contains("weeksPrior1"))
      assert(joinedDataFrames.columns.contains("weeksPrior2"))
      assert(joinedDataFrames.columns.contains("weeksPrior3"))
      assert(joinedDataFrames.columns.contains("weeksPrior4"))
      assert(joinedDataFrames.columns.contains("weeksPrior5"))
    }
    "when passed a list of one dataframe return that same dataframe" in {
      val sequenceOfDataFrames = Seq[DataFrame](fruitDf)
      val joinedDataFrame = sequenceOfDataFrames.joinDataFramesOnColumns(Seq("basket", "Product"))
      assert(joinedDataFrame.columns.sorted === fruitDf.columns.sorted)
      assert(joinedDataFrame.count === fruitDf.count)
    }
    "when passed an empty list of dataframes return an empty dataframe" in {
      val joinedDataFrame = Seq[DataFrame]().joinDataFramesOnColumns(Seq("basket"))
      assert(joinedDataFrame === spark.emptyDataFrame)
    }
    "when passed an empty list of joinColumns return the dataframes crossjoined" in {
      val sequenceOfDataFrames = Seq[DataFrame](fruitDf, fruitDf, fruitDf)
      val joinedDataFrame = sequenceOfDataFrames.joinDataFramesOnColumns(Seq[String]())
      assert(joinedDataFrame.count === scala.math.pow(fruitDf.count, sequenceOfDataFrames.size))
      assert(joinedDataFrame.columns.size === fruitDf.columns.size * sequenceOfDataFrames.size)
    }
  }
}
This all worked fine until it started erroring due to this Spark bug: https://issues.apache.org/jira/browse/SPARK-25150, which can cause errors in certain circumstances when the join columns have the same name.

The workaround is to alias the columns to something else, so I rewrote the function like this: it aliases the join columns, performs the join, then renames the aliased columns back:
implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]) {
  def joinDataFramesOnColumns(joinColumns: Seq[String]): DataFrame = {
    val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
    val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
    if (nonEmptyDataFrames.isEmpty) {
      emptyDataFrame
    }
    else {
      if (joinColumns.isEmpty) {
        return nonEmptyDataFrames.reduce(_.crossJoin(_))
      }
      /*
      The horrible, gnarly, inelegant code below would ideally exist simply as:
        nonEmptyDataFrames.reduce(_.join(_, joinColumns))
      however that will fail in certain specific circumstances due to a bug in Spark,
      see https://issues.apache.org/jira/browse/SPARK-25150 for details
      */
      val aliasSuffix = "_aliased"
      val aliasedJoinColumns = joinColumns.map(joinColumn => joinColumn + aliasSuffix)
      var aliasedNonEmptyDataFrames: Seq[DataFrame] = Seq()
      nonEmptyDataFrames.foreach(
        nonEmptyDataFrame => {
          var tempNonEmptyDataFrame = nonEmptyDataFrame
          joinColumns.foreach(
            joinColumn => {
              tempNonEmptyDataFrame = tempNonEmptyDataFrame.withColumnRenamed(joinColumn, joinColumn + aliasSuffix)
            }
          )
          aliasedNonEmptyDataFrames = aliasedNonEmptyDataFrames :+ tempNonEmptyDataFrame
        }
      )
      var joinedAliasedNonEmptyDataFrames = aliasedNonEmptyDataFrames.reduce(_.join(_, aliasedJoinColumns))
      joinColumns.foreach(
        joinColumn => joinedAliasedNonEmptyDataFrames = joinedAliasedNonEmptyDataFrames.withColumnRenamed(
          joinColumn + aliasSuffix, joinColumn
        )
      )
      joinedAliasedNonEmptyDataFrames
    }
  }
}
The tests still pass, so I'm fairly happy with it, but I keep looking at those vars and those loops that assign the result back to a var on every iteration... and I find them rather inelegant and rather ugly, especially compared to the original version of the function. I feel there must be a way to write this so that I don't have to use vars, but after some trial and error this is the best I could come up with.

Can anyone suggest a more elegant solution? As a newbie Scala developer it would really help me become more familiar with the idiomatic way of solving problems like this.

Any constructive comments on the rest of the code (e.g. the tests) would also be welcome.
Solution
Thanks to @Duelist, whose suggestion of using foldLeft() led me to "How does foldLeft in Scala work on DataFrame?", which in turn led me to adapting my code like this to eliminate the vars:
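Before the full solution, a small Spark-free sketch (hypothetical names) of what foldLeft does: it threads an explicit accumulator through the sequence, which replaces the var-and-foreach pattern in the function above.

```scala
object FoldLeftDemo extends App {
  val joinColumns = Seq("basket", "Product", "Customer")

  // var-based accumulation, as in the reworked function above
  var aliasedVar: Seq[String] = Seq()
  joinColumns.foreach(joinColumn => aliasedVar = aliasedVar :+ (joinColumn + "_aliased"))

  // foldLeft: the accumulator is passed explicitly into each step,
  // starting from the empty Seq, so no var is needed
  val aliasedFold = joinColumns.foldLeft(Seq[String]()) {
    (accumulated, joinColumn) => accumulated :+ (joinColumn + "_aliased")
  }

  assert(aliasedVar == aliasedFold)
  println(aliasedFold) // prints List(basket_aliased, Product_aliased, Customer_aliased)
}
```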
implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]) {
  def joinDataFramesOnColumns(joinColumns: Seq[String]): DataFrame = {
    val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
    val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
    if (nonEmptyDataFrames.isEmpty) {
      emptyDataFrame
    }
    else {
      if (joinColumns.isEmpty) {
        return nonEmptyDataFrames.reduce(_.crossJoin(_))
      }
      /*
      The code below would ideally exist simply as:
        nonEmptyDataFrames.reduce(_.join(_, joinColumns))
      however that will fail in certain specific circumstances due to a bug in Spark,
      see https://issues.apache.org/jira/browse/SPARK-25150 for details
      hence this code aliases the joinColumns, performs the join, then renames the
      aliased columns back to their original name
      */
      val aliasSuffix = "_aliased"
      val aliasedJoinColumns = joinColumns.map(joinColumn => joinColumn + aliasSuffix)
      val joinedAliasedNonEmptyDataFrames = nonEmptyDataFrames.foldLeft(Seq[DataFrame]()) {
        (tempDf, nonEmptyDataFrame) => tempDf :+ joinColumns.foldLeft(nonEmptyDataFrame) {
          (tempDf2, joinColumn) => tempDf2.withColumnRenamed(joinColumn, joinColumn + aliasSuffix)
        }
      }.reduce(_.join(_, aliasedJoinColumns))
      joinColumns.foldLeft(joinedAliasedNonEmptyDataFrames) {
        (tempDf, joinColumn) => tempDf.withColumnRenamed(joinColumn + aliasSuffix, joinColumn)
      }
    }
  }
}
I could have gone a step further and eliminated val joinedAliasedNonEmptyDataFrames by combining the two statements into one, but I prefer the clarity that the interim val provides.
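One further observation, offered as a suggestion rather than a definitive improvement: the outer foldLeft in the solution does nothing but append a transformed element to its accumulator, and a foldLeft of that shape is equivalent to a plain map. A Spark-free sketch (hypothetical names):

```scala
object FoldVsMapDemo extends App {
  val items = Seq("a", "b", "c")

  // a foldLeft whose step only appends a transformed element...
  val viaFold = items.foldLeft(Seq[String]()) {
    (accumulated, item) => accumulated :+ item.toUpperCase
  }

  // ...is just a map over the sequence
  val viaMap = items.map(_.toUpperCase)

  assert(viaFold == viaMap) // both are List(A, B, C)
}
```

Applied to the solution, the outer foldLeft could arguably be written as nonEmptyDataFrames.map of an inner foldLeft over joinColumns, which some readers may find reads more directly.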