首页 > 解决方案 > PySpark Iterator[pd.Series] UDF 中的输入列如何拆分成块?

问题描述

PySpark 3.0 引入了矢量化 Pandas UDF。其中一个示例显示了一个 UDF 注释为Iterator[pd.Sereis] -> Iterator[pd.Series]

from typing import Iterator

import pandas as pd

from pyspark.sql.functions import pandas_udf

pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)

# Declare the function and create the UDF
@pandas_udf("long")
def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    for x in iterator:
        yield x + 1

df.select(plus_one("x")).show()
# +-----------+
# |plus_one(x)|
# +-----------+
# |          2|
# |          3|
# |          4|
# +-----------+

我的理解是输入列[1,2,3]被拆分成块(比如说[1], [2], [3]然后连接回列。

x是什么决定了每次迭代中大小的大小?是否可配置?如果您不需要同时加载整个列,这种模式是否仅用于节省内存?

标签: pysparkapache-spark-sql

解决方案


推荐阅读