首页 > 解决方案 > 如何通过继承向 Pyspark Dataframe 类添加自定义方法

问题描述

我正在尝试继承 DataFrame 类并添加如下额外的自定义方法,以便我可以流畅地链接并确保所有方法都引用相同的数据帧。我得到一个异常,因为列不可迭代

from pyspark.sql.dataframe import DataFrame

class Myclass(DataFrame):
def __init__(self,df):
    super().__init__(df._jdf, df.sql_ctx)

def add_column3(self):
 // Add column1 to dataframe received
  self._jdf.withColumn("col3",lit(3))
  return self

def add_column4(self):
 // Add column to dataframe received
  self._jdf.withColumn("col4",lit(4))
  return self

if __name__ == "__main__":
'''
Spark Context initialization code
col1 col2
a 1
b 2
'''
  df = spark.createDataFrame([("a",1), ("b",2)], ["col1","col2"])
  myobj = MyClass(df)
  ## Trying to accomplish below where i can chain MyClass methods & Dataframe methods
  myobj.add_column3().add_column4().drop_columns(["col1"])

'''
Expected Output
col2, col3,col4
1,3,4
2,3,4
'''

标签: pythonapache-sparkpyspark

解决方案


以下是我的解决方案(基于您的代码)。我不知道这是否是最佳做法,但至少可以正确执行您想要的操作。Dataframes 是不可变的对象,所以在我们添加一个新列之后,我们创建一个新对象但不是一个Dataframe对象而是一个Myclass对象,因为我们想要有 Dataframe 和自定义方法。

from pyspark.sql.dataframe import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()


class MyClass(DataFrame):
   def __init__(self,df):
      super().__init__(df._jdf, df.sql_ctx)
      self._df = df

  def add_column3(self):
      #Add column1 to dataframe received
      newDf=self._df.withColumn("col3",F.lit(3))
      return MyClass(newDf)

  def add_column4(self):
      #Add column2 to dataframe received
      newDf=self._df.withColumn("col4",F.lit(4))
      return MyClass(newDf)


df = spark.createDataFrame([("a",1), ("b",2)], ["col1","col2"])
myobj = MyClass(df)
myobj.add_column3().add_column4().na.drop().show()

# Result:
+----+----+----+----+
|col1|col2|col3|col4|
+----+----+----+----+
|   a|   1|   3|   4|
|   b|   2|   3|   4|
+----+----+----+----+

推荐阅读