首页 > 解决方案 > 如何评估包装在单一方法中的 pyspark 转换?

问题描述

我正在尝试组织在 pyspark 中执行的几个数据转换。我有类似于下面的代码。

def main():
    spark_session = SparkSession\
        .builder\
        .appName(config.SPARK_CONFIG['AppName']) \
        .getOrCreate()
    data = getData(spark_session)

    analytics = Analytics(data)
    analytics.execute_and_save_analytics()

    spark_session.stop()


def getData(spark_session):    
    sqlContext = pyspark.SQLContext(spark_session.sparkContext)
    return sqlContext.read.option('user', user).option('password', pswd)\
        .jdbc('jdbc:sqlserver://' + sqlserver + ':' + port\
        + ';database=' + database, table)


class Analytics():
    def __init__(self, df):
        self.df = df

    def _execute(self):
        df0 = self.df.withColumn('col3', df.col31 + df.col32)
        # df0.persist()
        df1 = df0.filter(df.col3 > 10).groupBy('col1', 'col2').count()
        df2 = df0.filter(df.col3 < 10).groupBy('col1', 'col2').count()
        return df1, df2

    def execute_and_save_analytics(self):
        output_df1, output_df2 = self._execute()
        output_df1.coalesce(1).write.csv('/path/file.csv', header='true')
        output_df2.coalesce(1).write.csv('/path/file.csv', header='true')

我怎样才能以这种方式重新组织代码,df0 只会被评估一次?我尝试在注释行中使用 persist() ,但没有任何性能改进。有什么想法吗?

另一个类似的问题,如果你没有一个 _execute(),但有许多类似的方法 _execute1()、_execute2() 等,你将如何组织你的管道。我想如果我分别调用 _execute() 方法,那么 PySpark 将分别评估每个转换管道(?),因此我失去了性能。

编辑:给定转换(过滤器、分组依据、计数)只是示例,我正在寻找使用任何类型的转换或 col3 定义的解决方案。

edit2:似乎在 init 中调用 cache() 是这里最好的优化。

标签: apache-sparkpysparkspark-dataframe

解决方案


照原样(persist注释掉)df0无论如何都会被评估两次。您的代码结构根本不会产生任何影响。

将您的代码拆分为

def _execute_1(self):
    df0 = self.df.withColumn('col3', df.col31 + df.col32)
    df1 = df0.filter(df.col3 > 10).groupBy('col1', 'col2').count()
    return df1

def _execute_2(self):
    df0 = self.df.withColumn('col3', df.col31 + df.col32)
    df2 = df0.filter(df.col3 < 10).groupBy('col1', 'col2').count()
    return df2

不会有任何区别。在不详细说明cache保证的情况下,您可以:

def __init__(self, df):
    self.df = df.withColumn('col3', df.col31 + df.col32).cache()

def _execute_1(self):
    return df0.filter(df.col3 > 10).groupBy('col1', 'col2').count()

def _execute_2(self):
    return df0.filter(df.col3 < 10).groupBy('col1', 'col2').count()

def execute_and_save_analytics(self):
    output_df1 = self._execute_1()
    output_df2 = self._execute_2()
    output_df1.coalesce(1).write.csv('/path/file1.csv', header='true')
    output_df2.coalesce(1).write.csv('/path/file2.csv', header='true')
    self.df.unpersist()

但它可能更容易:

(self.df
  .withColumn('col3', df.col31 + df.col32 > 10)
  .repartition("col3")
  .write.partitionBy("col3")
  .write.csv('/path/file.csv', header='true'))

推荐阅读