首页 > 解决方案 > 使用pyspark时如何在agg和groupBy中使用lambda?

问题描述

我只是在学习 pyspark。我对以下代码感到困惑:

df.groupBy(['Category','Register']).agg({'NetValue':'sum',
                                     'Units':'mean'}).show(5,truncate=False)

df.groupBy(['Category','Register']).agg({'NetValue':'sum',
                                     'Units': lambda x: pd.Series(x).nunique()}).show(5,truncate=False)

第一行是正确的。但是第二行是不正确的。错误信息是:

AttributeError: 'function' object has no attribute '_get_object_id'

看来我没有正确使用 lambda 函数。但这就是我在普通 python 环境中使用 lambda 的方式,而且是正确的。

有人可以在这里帮助我吗?

标签: pythonlambdapysparkaggregation

解决方案


如果您对使用纯 Python 函数的 PySpark 原语的性能感到满意,以下代码将给出所需的结果。您可以修改逻辑_map以满足您的特定需求。我对您的数据模式的外观做了一些假设。

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, LongType

schema = StructType([
    StructField('Category', StringType(), True),
    StructField('Register', LongType(), True),
    StructField('NetValue', LongType(), True),
    StructField('Units', LongType(), True)
])

test_records = [
    {'Category': 'foo', 'Register': 1, 'NetValue': 1, 'Units': 1},
    {'Category': 'foo', 'Register': 1, 'NetValue': 2, 'Units': 2},
    {'Category': 'foo', 'Register': 2, 'NetValue': 3, 'Units': 3},
    {'Category': 'foo', 'Register': 2, 'NetValue': 4, 'Units': 4},
    {'Category': 'bar', 'Register': 1, 'NetValue': 5, 'Units': 5}, 
    {'Category': 'bar', 'Register': 1, 'NetValue': 6, 'Units': 6}, 
    {'Category': 'bar', 'Register': 2, 'NetValue': 7, 'Units': 7},
    {'Category': 'bar', 'Register': 2, 'NetValue': 8, 'Units': 8}
]

spark = SparkSession.builder.getOrCreate()
dataframe = spark.createDataFrame(test_records, schema)

def _map(((category, register), records)):
    net_value_sum = 0
    uniques = set()
    for record in records:
        net_value_sum += record['NetValue']
        uniques.add(record['Units'])
    return category, register, net_value_sum, len(uniques)

new_dataframe = spark.createDataFrame(
    dataframe.rdd.groupBy(lambda x: (x['Category'], x['Register'])).map(_map),
    schema
)
new_dataframe.show()

结果:

+--------+--------+--------+-----+
|Category|Register|NetValue|Units|
+--------+--------+--------+-----+
|     bar|       2|      15|    2|
|     foo|       1|       3|    2|
|     foo|       2|       7|    2|
|     bar|       1|      11|    2|
+--------+--------+--------+-----+

如果您需要性能或坚持使用 pyspark.sql 框架,请查看此相关问题及其链接问题:

PySpark 数据帧上的自定义聚合


推荐阅读