首页 > 解决方案 > PySpark如何高效处理嵌套数据?

问题描述

我这里有个情况,发现spark中collect_list在item已经是list的时候效率不高。

基本上,我试图计算嵌套列表的平均值(每个列表的大小保证相同)。例如,当数据集变为 10 M 行时,可能会产生内存不足错误。最初,我认为它与 udf (计算平均值)有关。但实际上,我发现聚合部分(collect_list of lists)才是真正的问题。

我现在正在做的是将 10 M 行分成多个块(按“用户”),单独聚合每个块,然后在最后合并它们。关于有效处理嵌套数据有什么更好的建议吗?

例如,玩具示例是这样的:

data = [('user1','place1', ['place1', 'place2', 'place3'], [0.0, 0.5, 0.4], [0.0, 0.4, 0.3]),
    ('user1','place2', ['place1', 'place2', 'place3'], [0.7, 0.0, 0.4], [0.6, 0.0, 0.3]),
    ('user2','place1', ['place1', 'place2', 'place3'], [0.0, 0.4, 0.3], [0.0, 0.3, 0.4]),
    ('user2','place3', ['place1', 'place2', 'place3'], [0.1, 0.2, 0.0], [0.3, 0.1, 0.0]),
    ('user3','place2', ['place1', 'place2', 'place3'], [0.3, 0.0, 0.4], [0.2, 0.0, 0.4]),
   ]
data_df = sparkApp.sparkSession.createDataFrame(data, ['user', 'place', 'places', 'data1', 'data2'])

data_agg = data_df.groupBy('user') \
    .agg(f.collect_list('place').alias('place_list'),
         f.first('places').alias('places'),
         f.collect_list('data1').alias('data1'),
         f.collect_list('data1').alias('data2'),
        )

import numpy as np
def average_values(sim_vectors):
    if len(sim_vectors) == 1:
        return sim_vectors[0]
    mat = np.array(sim_vectors)
    mean_vector = np.mean(mat, axis=0)
    return np.round(mean_vector, 3).tolist()

avg_vectors_udf = f.udf(average_values, ArrayType(DoubleType()))
data_agg_ave = data_agg.withColumn('data1', avg_vectors_udf('data1')) \
    .withColumn('data2', avg_vectors_udf('data2'))

结果将是:

+-----+----------------+--------------------+-----------------+-----------------+

| user|      place_list|              places|            data1|            data2|

+-----+----------------+--------------------+-----------------+-----------------+

|user1|[place1, place2]|[place1, place2, ...|[0.35, 0.25, 0.4]|[0.35, 0.25, 0.4]|

|user3|        [place2]|[place1, place2, ...|  [0.3, 0.0, 0.4]|  [0.3, 0.0, 0.4]|

|user2|[place1, place3]|[place1, place2, ...|[0.05, 0.3, 0.15]|[0.05, 0.3, 0.15]|

+-----+----------------+--------------------+-----------------+-----------------+

标签: python-3.xpysparkpyspark-sql

解决方案


推荐阅读