首页 > 解决方案 > pandas udf拆分字符串数组pyspark

问题描述

我有下表

id | country_mapping
--------------------
1  | {"GBR/bla": 1,
      "USA/bla": 2}

我想创建一个包含以下内容的列

id | source_countries
--------------------
1  | ["GBR", "USA"]

我需要通过 pandas udf 来完成。我创建了以下

import pyspark.sql.functions as F

@F.pandas_udf("string")
def func(s):
    return s.apply(lambda x: [y.split("/")[0] for y in x])

我认为这会起作用,因为如果我在纯熊猫中运行此代码,它会提供我需要的东西。

import pandas as pd
s = pd.Series([["GBR/1", "USA/2"], ["ITA/1", "FRA/2"]])
s.apply(lambda x: [y.split("/")[0] for y in x])

Out[1]: 0    [GBR, USA]
        1    [ITA, FRA]
dtype: object

但是当我跑步时

df.withColumn('source_countries', 
              func(F.map_keys(F.col("country_mapping")))).collect()

当我运行以下命令时,它失败并出现以下错误:

PythonException: An exception was thrown from a UDF: 'pyarrow.lib.ArrowTypeError: Expected bytes, got a 'list' object'

我对为什么 - 以及如何修复我的 pandas udf 感到困惑。

标签: pandasapache-sparkpysparkapache-spark-sqluser-defined-functions

解决方案


而不是pandas_udf,您可以udf以类似的方式使用

from pyspark.sql import functions as F
from pyspark.sql import types as T

def func(v):
    return [x.split('/')[0] for x in v]

(df
     .withColumn('source_countries', F.udf(func, T.ArrayType(T.StringType()))(F.map_keys(F.col('country_mapping'))))
     .show(10, False)
)

# +---+----------------------------+----------------+
# |id |country_mapping             |source_countries|
# +---+----------------------------+----------------+
# |1  |{USA/bla -> 2, GBR/bla -> 1}|[USA, GBR]      |
# +---+----------------------------+----------------+

推荐阅读