首页 > 解决方案 > 在 PySpark 中使用 pandas_udf 时无法填充数组

问题描述

我有一个 PySpark 数据框,就像

+---+------+------+
|key|value1|value2|
+---+------+------+
|  a|     1|     0|
|  a|     1|    42|
|  b|     3|    -1|
|  b|    10|    -2|
+---+------+------+

我已经定义了一个 pandas_udf 像 -

schema = StructType([
    StructField("key", StringType())
])

arr = []
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def g(df):
    k = df.key.iloc[0]
    series = [d for d in df.value2]
    arr.append(len(series))
    print(series)
    return pd.DataFrame([k])
df3.groupby("key").apply(g).collect()
print(arr)

很明显,数组 arr 应该是 [2, 2],但它仍然是空的。当我检查驱动程序日志时, print(series) 的输出看起来是正确的,但数组仍然是空的。

返回类型对我来说并不重要,因为我没有更改/处理数据,我只想将它推送到自定义类对象中。

标签: pandasapache-sparkpysparkpandas-groupbyuser-defined-functions

解决方案


我必须为列表定义一个自定义累加器并使用它。

from pyspark.accumulators import AccumulatorParam
class ListParam(AccumulatorParam):
    def zero(self, val):
        return []
    def addInPlace(self, val1, val2):
        val1.append(val2)
        return val1

推荐阅读