首页 > 解决方案 > 在 Apache Spark 中有效地运行“for”循环,以便并行执行

问题描述

我们如何在 Spark 中并行化一个循环,以便处理不是顺序的并且是并行的。举个例子 - 我在一个包含以下数据的 csv 文件(称为“bill_item.csv”)中包含以下数据:

    |-----------+------------|
    | bill_id   | item_id    |
    |-----------+------------|
    | ABC       | 1          |
    | ABC       | 2          |
    | DEF       | 1          |
    | DEF       | 2          |
    | DEF       | 3          |
    | GHI       | 1          |
    |-----------+------------|

我必须得到如下输出:

    |-----------+-----------+--------------|
    | item_1    | item_2    | Num_of_bills |
    |-----------+-----------+--------------|
    | 1         | 2         | 2            |
    | 2         | 3         | 1            |
    | 1         | 3         | 1            |
    |-----------+-----------+--------------|

我们看到项目 1 和 2 已在 2 个账单“ABC”和“DEF”下找到,因此项目 1 和 2 的“Num_of_bills”为 2。类似地,项目 2 和 3 仅在账单“DEF”下找到,因此'Num_of_bills' 列是 '1' 等等。

我正在使用 spark 来处理 CSV 文件'bill_item.csv',并且我正在使用以下方法:

方法一:

from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# define the schema for the data 
bi_schema = StructType([
    StructField("bill_id", StringType(), True), 
    StructField("item_id", IntegerType(), True) 
]) 

bi_df = (spark.read.schema(dataSchema).csv('bill_item.csv'))

# find the list of all items in sorted order
item_list = bi_df.select("item_id").distinct().orderBy("item_id").collect()

item_list_len = len(item_list)
i = 0
# for each pair of items for e.g. (1,2), (1,3), (1,4), (1,5), (2,3), (2,4), (2,5), ...... (4,5)
while i < item_list_len - 1:
    # find the list of all bill IDs that contain item '1'
    bill_id_list1 = bi_df.filter(bi_df.item_id == item_list[i].item_id).select("bill_id").collect()
    j = i+1
    while j < item_list_len:
        # find the list of all bill IDs that contain item '2'
        bill_id_list2 = bi_df.filter(bi_df.item_id == item_list[j].item_id).select("bill_id").collect()

        # find the common bill IDs in list bill_id_list1 and bill_id_list2 and then the no. of common items
        common_elements = set(basket_id_list1).intersection(bill_id_list2)
        num_bils = len(common_elements)
        if(num_bils > 0):
            print(item_list[i].item_id, item_list[j].item_id, num_bils)
        j += 1    
    i+=1

但是,鉴于在现实生活中我们有数百万条记录,并且可能存在以下问题,这种方法并不是一种有效的方法:

  1. 可能没有足够的内存来加载所有项目或账单的列表
  2. 获得结果可能需要很长时间,因为执行是顺序的(感谢“for”循环)。(我用〜200000条记录运行上述算法,花了4个多小时才得出预期的结果。)

方法二:

我通过基于“item_id”拆分数据进一步优化了这一点,并使用以下代码块拆分数据:

bi_df = (spark.read.schema(dataSchema).csv('bill_item.csv'))
outputPath='/path/to/save'
bi_df.write.partitionBy("item_id").csv(outputPath)

拆分后,我执行了在“方法 1”中使用的相同算法,我发现在 200000 条记录的情况下,仍然需要 1.03 小时(与“方法 1”下的 4 小时相比有显着改进)才能获得最终输出。

上述瓶颈是因为顺序的“for”循环(也因为“collect()”方法)。所以我的问题是:

标签: pythonapache-sparkbigdataapache-spark-datasetapache-spark-2.0

解决方案


在火花中循环总是顺序的,在代码中使用它也不是一个好主意。根据您的代码,您一次使用while和读取单个记录,这将不允许 spark 并行运行。

如果您有大型数据集, Spark 代码应该在没有forwhile循环的情况下设计。

根据我对您的问题的理解,我已经在 scala 中编写了示例代码,它可以在不使用任何循环的情况下提供您想要的输出。请以下面的代码作为参考,并尝试以相同的方式设计代码。

注意:我已经用 Scala 编写了代码,这些代码也可以用相同的逻辑在 Python 中实现。

scala> import org.apache.spark.sql.expressions.UserDefinedFunction

scala> def sampleUDF:UserDefinedFunction = udf((flagCol:String) => {var out = ""
     |       val flagColList = flagCol.reverse.split(s""",""").map(x => x.trim).mkString(",").reverse.split(s",").toList
     |       var i = 0
     |     var ss = flagColList.size
     |     flagColList.foreach{ x =>
     |        i =  i + 1
     |      val xs = List(flagColList(i-1))
     |      val ys =  flagColList.slice(i, ss)
     |      for (x <- xs; y <- ys)  
     |           out = out +","+x + "~" + y
     |         }
     |             if(out == "") { out = flagCol}
     |    out.replaceFirst(s""",""","")})

//Input DataSet 
scala> df.show
+-------+-------+
|bill_id|item_id|
+-------+-------+
|    ABC|      1|
|    ABC|      2|
|    DEF|      1|
|    DEF|      2|
|    DEF|      3|
|    GHI|      1|
+-------+-------+

//Collectin all item_id corresponding to bill_id

scala> val df1 = df.groupBy("bill_id")
               .agg(concat_ws(",",collect_list(col("item_id"))).alias("item"))

scala> df1.show
+-------+-----+
|bill_id| item|
+-------+-----+
|    DEF|1,2,3|
|    GHI|    1|
|    ABC|  1,2|
+-------+-----+


//Generating combination of all item_id and filter out for correct data

scala>   val df2 = df1.withColumn("item", sampleUDF(col("item")))
                      .withColumn("item", explode(split(col("item"), ",")))
                      .withColumn("Item_1", split(col("item"), "~")(0))
                      .withColumn("Item_2", split(col("item"), "~")(1))
                      .groupBy(col("Item_1"),col("Item_2"))
                      .agg(count(lit(1)).alias("Num_of_bills"))
                      .filter(col("Item_2").isNotNull)

scala> df2.show
+------+------+------------+
|Item_1|Item_2|Num_of_bills|
+------+------+------------+
|     2|     3|           1|
|     1|     2|           2|
|     1|     3|           1|
+------+------+------------+

推荐阅读