Iterating over a PySpark DataFrame to find hierarchical lineage

Problem description

I have to create a data lineage for group relationships. Below you can see the source and the target. For example, id 1 is related to group A, group A is also related to id 2, id 2 also belongs to group C, and id 1 also belongs to group D. Following these links transitively, id 1 is related to all of these groups, which gives the relation string ABCD in the target (a small pure-Python sketch after the target table makes this concrete).

I wrote code for this, but it is not a standard way to write it: I should not have to create a DF and then iterate over it on the driver.

Source

id,group
1,A
2,A
1,B
3,B
2,C
3,C
1,D
4,D
5,D
6,E
7,E
8,F
6,F
9,G 

Target

+---+-----+--------+
| id|group|relation|
+---+-----+--------+
|  7|    E|      EF|
|  3|    B|    ABCD|
|  8|    F|      EF|
|  5|    D|    ABCD|
|  6|    E|      EF|
|  9|    G|       G|
|  1|    A|    ABCD|
|  4|    D|    ABCD|
|  2|    A|    ABCD|
+---+-----+--------+
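
To make the expected mapping concrete, here is a minimal pure-Python sketch (my own illustration, not part of the original post) that expands ids and groups transitively over the source pairs:

# Pure-Python illustration of the transitive expansion
# (connected components over a bipartite id/group graph).
pairs = [("1", "A"), ("2", "A"), ("1", "B"), ("3", "B"),
         ("2", "C"), ("3", "C"), ("1", "D"), ("4", "D"),
         ("5", "D"), ("6", "E"), ("7", "E"), ("8", "F"),
         ("6", "F"), ("9", "G")]

def relation(start_id, start_group):
    ids, groups = {start_id}, {start_group}
    changed = True
    while changed:                      # repeat until a fixed point
        changed = False
        for i, g in pairs:
            if i in ids and g not in groups:
                groups.add(g); changed = True
            if g in groups and i not in ids:
                ids.add(i); changed = True
    return "".join(sorted(groups))

print(relation("1", "A"))  # ABCD
print(relation("6", "E"))  # EF
print(relation("9", "G"))  # G

The Spark code below performs the same fixed-point expansion, only by querying the DataFrame at each step.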

Code used

import pyspark.sql.functions as F
from pyspark.sql import Row
from pyspark.sql.functions import row_number
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql.window import Window

schema = StructType([StructField("id", StringType(), True),
                     StructField("group", StringType(), True)])
rdd = sc.parallelize([Row("1", "A"), Row("2", "A"), Row("1", "B"), Row("3", "B"),
                      Row("2", "C"), Row("3", "C"), Row("1", "D"), Row("4", "D"),
                      Row("5", "D"), Row("6", "E"), Row("7", "E"), Row("8", "F"),
                      Row("6", "F"), Row("9", "G")])
# create the sample df
df = sqlContext.createDataFrame(rdd, schema)
final_result = []
# pick the first row of each id
df4 = (df.withColumn("row_num", row_number().over(Window.partitionBy("id").orderBy("group")))
         .where(F.col("row_num") == 1)
         .drop("row_num"))
data = [(part["id"], part["group"]) for part in df4.collect()]
# iterate over the collected rows to find each relation
for key, value in data:
    itid = [key]        # ids still to be expanded
    doneid = []         # ids already expanded
    itgroup = [value]   # groups still to be expanded
    donegroup = []      # groups already expanded
    while (set(itid) - set(doneid)) or (set(itgroup) - set(donegroup)):
        for cur_id in set(itid) - set(doneid):
            itgroup.extend([part["group"] for part in df.select("group").where(df["id"] == cur_id).collect()])
            doneid.append(cur_id)
        for group in set(itgroup) - set(donegroup):
            itid.extend([part["id"] for part in df.select("id").where(df["group"] == group).collect()])
            donegroup.append(group)
    res = ''.join(sorted(donegroup))
    if len(res) > 0:
        # append to the final list
        final_result.append([key, value, res])

cSchema = StructType([StructField("id", StringType()),
                      StructField("group", StringType()),
                      StructField("relation", StringType())])
result = spark.createDataFrame(final_result, schema=cSchema)
# create the final df and show it
result.show()

This works and gives the correct result, but every lookup inside the while loop fires a separate Spark job, so it will not scale:

+---+-----+--------+
| id|group|relation|
+---+-----+--------+
|  7|    E|      EF|
|  3|    B|    ABCD|
|  8|    F|      EF|
|  5|    D|    ABCD|
|  6|    E|      EF|
|  9|    G|       G|
|  1|    A|    ABCD|
|  4|    D|    ABCD|
|  2|    A|    ABCD|
+---+-----+--------+

I want to know the best way to write this code. I tried a UDF, but it says you cannot pass a DataFrame into a UDF.

Again, I cannot create a DF and iterate over it inside the UDF. First, a function that returns the data df:

def getdata():
    schema = StructType([StructField("id", StringType(), True),
                         StructField("group", StringType(), True)])
    rdd = sc.parallelize([Row("1", "A"), Row("2", "A"), Row("1", "B"), Row("3", "B"),
                          Row("2", "C"), Row("3", "C"), Row("1", "D"), Row("4", "D"),
                          Row("5", "D"), Row("6", "E"), Row("7", "E"), Row("8", "F"),
                          Row("6", "F"), Row("9", "G")])
    df = sqlContext.createDataFrame(rdd, schema)
    df.persist()
    return df

Then a function that extracts the relation string for one input row, e.g. (1, "A"):

def getMember(key, value):
    df = getdata()
    itid = [key]
    doneid = []
    itgroup = [value]
    donegroup = []
    while (set(itid) - set(doneid)) or (set(itgroup) - set(donegroup)):
        for cur_id in set(itid) - set(doneid):
            itgroup.extend([part["group"] for part in df.select("group").where(df["id"] == cur_id).collect()])
            doneid.append(cur_id)
        for group in set(itgroup) - set(donegroup):
            itid.extend([part["id"] for part in df.select("id").where(df["group"] == group).collect()])
            donegroup.append(group)
    return ''.join(sorted(donegroup))

udf_getMember = F.udf(getMember, StringType())
schema = StructType([StructField("id", StringType(), True),
                     StructField("group", StringType(), True)])
rdd = sc.parallelize([Row("1", "A"), Row("2", "A"), Row("1", "B"), Row("3", "B"),
                      Row("2", "C"), Row("3", "C"), Row("1", "D"), Row("4", "D"),
                      Row("5", "D"), Row("6", "E"), Row("7", "E"), Row("8", "F"),
                      Row("6", "F"), Row("9", "G")])
df3 = sqlContext.createDataFrame(rdd, schema)
df4 = (df3.withColumn("row_num", row_number().over(Window.partitionBy("id").orderBy("group")))
          .where(F.col("row_num") == 1)
          .drop("row_num"))
df4.withColumn('result', udf_getMember(F.col("id"), F.col("group"))).show()

This does not work and gives me a pickle error. (A Python UDF runs on the executors, so Spark has to pickle its closure; here the closure pulls in getdata() and with it the SparkContext and a DataFrame, which only exist on the driver and cannot be serialized.)
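
One way to keep a UDF-based version working, sketched here as an assumption rather than the original fix (the names pairs, bc_pairs, get_member and relation_udf are mine; df and df4 are reused from the code above), is to collect the (id, group) pairs once and broadcast them, so the UDF closes over plain Python data instead of a DataFrame:

import pyspark.sql.functions as F
from pyspark.sql.types import StringType

# Collect the edges once on the driver and broadcast plain Python data;
# never reference the DataFrame or SparkContext inside the UDF.
pairs = [(r["id"], r["group"]) for r in df.collect()]
bc_pairs = sc.broadcast(pairs)

def get_member(key, value):
    ids, groups = {key}, {value}
    changed = True
    while changed:                      # expand to a fixed point
        changed = False
        for i, g in bc_pairs.value:
            if i in ids and g not in groups:
                groups.add(g); changed = True
            if g in groups and i not in ids:
                ids.add(i); changed = True
    return "".join(sorted(groups))

relation_udf = F.udf(get_member, StringType())
df4.withColumn("relation", relation_udf(F.col("id"), F.col("group"))).show()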

Tags: pyspark, apache-spark-sql

Solution
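
A standard way to solve this is to treat each (id, group) pair as an edge of a bipartite graph and run connected components. A minimal sketch, assuming the GraphFrames package is installed and reusing df and df4 from the question (the checkpoint path is a placeholder):

from graphframes import GraphFrame
import pyspark.sql.functions as F

sc.setCheckpointDir("/tmp/graphframes-checkpoint")  # placeholder path

# Vertices are all ids and all groups, prefixed so they cannot collide.
vertices = (df.select(F.concat(F.lit("i_"), "id").alias("id"))
              .union(df.select(F.concat(F.lit("g_"), "group")))
              .distinct())
edges = df.select(F.concat(F.lit("i_"), "id").alias("src"),
                  F.concat(F.lit("g_"), "group").alias("dst"))

# connectedComponents() adds a "component" column to the vertices.
components = GraphFrame(vertices, edges).connectedComponents()

# Map every group to its component id.
group_comp = (components
              .where(F.col("id").startswith("g_"))
              .select(F.expr("substring(id, 3)").alias("group"), "component"))

# One relation string per component: sorted, concatenated group names.
relations = (group_comp
             .groupBy("component")
             .agg(F.concat_ws("", F.sort_array(F.collect_set("group"))).alias("relation")))

# Attach the relation to the one row kept per id.
(df4.join(group_comp, "group")
    .join(relations, "component")
    .select("id", "group", "relation")
    .show())

This replaces the per-row driver loop with a single distributed computation; the while loop in the question is exactly what connected components does, but computed once for all ids at the same time.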

