首页 > 解决方案 > 如何将变量传递给 UDAF(用户定义聚合函数)

问题描述

import pandas as pd
import numpy as np
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.functions import PandasUDFType, pandas_udf
from pyspark.sql.types import *
import os



@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
def split(df, validation_period):

   ""Logic""

    return df

def train_test_split(spark, data_frame, request_json_data):

    data_frame = spark.createDataFrame(data_frame)
    print(data_frame.schema)
 

    validation_period = request_json_data['validation_period']
    groupby_key = request_json_data['groupby_key']

    data_frame.groupby(groupby_key).apply(split, validation_period).show()

不能调用 split 函数它给出错误。apply() 接受 2 个位置参数,但给出了 3 个。我想将 validation_period 作为参数传递给 split 函数。

标签: pysparkapache-spark-sql

解决方案


简短的回答:您不能将额外的参数传递给 pandas grouped map udf,因为它只将单个 pandas df 作为参数除外。

长答案:还有其他方法可以将 validation_period 传递给函数

  1. 使用某种形式的闭包

    def split_fabric(validation_period):
        @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
        def split(df):
    
            ""Logic""
    
            return df
    
  2. 将其作为列传递

    data_frame \
        .withColumn("validation_period", F.lit(validation_period)) \
        .groupby(groupby_key).apply(split, validation_period).show()
    

推荐阅读