首页 > 解决方案 > 寻找 pyspark 的 arrays_zip 的逆

问题描述

我有以下格式讨厌的输入数据框:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.master("local").getOrCreate()

input_df = spark.createDataFrame(
    [
        ('Alice;Bob;Carol',),
        ('12;13;14',),
        ('5;;7',),
        ('1;;3',),
        (';;3',)
    ],
    ['data']
)
  
input_df.show()

# +---------------+
# |           data|
# +---------------+
# |Alice;Bob;Carol|
# |       12;13;14|
# |           5;;7|
# |           1;;3|
# |            ;;3|
# +---------------+

实际输入是一个分号分隔的 CSV 文件,其中一列包含一个人的值。每个人可以有不同数量的值。在这里,Alice 有 3 个值,Bob 只有一个,Carol 有 4 个值。

我想在 PySpark 中将其转换为一个输出数据框,该数据框为每个人保存一个数组,在此示例中,输出将是:

result = spark.createDataFrame(
    [
        ("Alice", [12, 5, 1]),
        ("Bob", [13,]),
        ("Carol", [14, 7, 3, 3])
    ],
    ['name', 'values']
)

result.show()

# +-----+-------------+
# | name|       values|
# +-----+-------------+
# |Alice|   [12, 5, 1]|
# |  Bob|         [13]|
# |Carol|[14, 7, 3, 3]|
# +-----+-------------+

我该怎么做?我认为这将是F.arrays_zip(),F.split()和/或的某种组合F.explode(),但我无法弄清楚。

我目前被困在这里,这是我现在的尝试:

(input_df
    .withColumn('splits', F.split(F.col('data'), ';'))
    .drop('data')
).show()

# +-------------------+
# |             splits|
# +-------------------+
# |[Alice, Bob, Carol]|
# |       [12, 13, 14]|
# |           [5, , 7]|
# |           [1, , 3]|
# |            [, , 3]|
# +-------------------+

标签: pythonapache-sparkpysparkpyspark-dataframes

解决方案


一种方法可以是读取第一行作为标题,然后取消透视数据

df1 = spark.createDataFrame([(12,13,14),(5,None,7),(1,None,3),(None,None,3)], ['Alice','Bob','Carol'])

df1.show()
+-----+----+-----+
|Alice| Bob|Carol|
+-----+----+-----+
|   12|  13|   14|
|    5|null|    7|
|    1|null|    3|
| null|null|    3|
+-----+----+-----+

df1.select(f.expr('''stack(3,'Alice',Alice,'Bob',Bob,'Carol',Carol) as (Name,Value)'''))\
   .groupBy('Name').agg(f.collect_list('value').alias('Value')).orderBy('Name').show()

+-----+-------------+
| Name|        Value|
+-----+-------------+
|Alice|   [12, 5, 1]|
|  Bob|         [13]|
|Carol|[14, 7, 3, 3]|
+-----+-------------+

对于动态传递列,请使用以下代码

cols = ','.join([f"'{i[0]}',{i[1]}" for i in zip(df1.columns,df1.columns)])
df1.select(f.expr(f'''stack(3,{cols}) as (Name,Value)''')).groupBy('Name').agg(f.collect_list('value').alias('Value')).orderBy('Name').show()

+-----+-------------+
| Name|        Value|
+-----+-------------+
|Alice|   [12, 5, 1]|
|  Bob|         [13]|
|Carol|[14, 7, 3, 3]|
+-----+-------------+

推荐阅读