Get all rows that share at least one value in an array column

Problem description

I am aware of this thread:
Spark get all rows with the same values in array in column
I've tried, but I can't find a way to write its accepted answer (Scala) in PySpark:

df.join(
    df.withColumnRenamed("id", "id2").withColumnRenamed("hashes", "hashes2"),
    exists(arrays_zip(col("hashes"), col("hashes2")), x => x("hashes") === x("hashes2"))
  )
  .groupBy("id")
  .agg(first(col("hashes")).as("hashes"), collect_list("id2").as("matched"))
  .withColumn("matched", filter(col("matched"), x => x.notEqual(col("id"))))

In PySpark on Spark 2.4.7, the function pyspark.sql.functions.exists does not exist.

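The main difference from the request in that thread is that I don't need the matching elements to be at the same position in the arrays. So, given: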

+---+-------------------------+
|id |hashes                   |
+---+-------------------------+
|0  |["1", "2", "3", "4", "5"]|
|1  |["1", "5", "3", "7", "9"]|
|2  |["9", "7", "6", "8", "0"]|
+---+-------------------------+

the result should look like this:

+---+-------------------------+-----------+
|id |hashes                   |matches    |
+---+-------------------------+-----------+
|0  |["1", "2", "3", "4", "5"]|["1"]      |
|1  |["1", "5", "3", "7", "9"]|["0","2"]  |
|2  |["9", "7", "6", "8", "0"]|["1"]      |
+---+-------------------------+-----------+

Note that my array elements are of type string.

Can you help me solve this? Also, is there a more efficient way to achieve what was asked in that thread? Thank you very much.

Tags: python, dataframe, apache-spark, pyspark

Solution


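As I understand it, you can cross join the DataFrame with itself and use arrays_overlap (available since Spark 2.4) to check whether the hashes of another id share at least one value with the current row's hashes. Keep only the pairs for which it returns true, then group by id: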

from pyspark.sql import functions as F

# Copy of df with renamed columns, so each row can be paired with every other row.
other = df.select(F.col("id").alias("id1"), F.col("hashes").alias("hashes1"))

out = (df.crossJoin(other)
       .where("id != id1")                            # drop self-pairs
       .where(F.arrays_overlap("hashes", "hashes1"))  # keep pairs sharing at least one value
       .groupBy("id")
       .agg(F.first("hashes").alias("hashes"),
            F.collect_list(F.col("id1").cast("string")).alias("Matches")))

out.show(truncate=False)

+---+---------------+-------+
|id |hashes         |Matches|
+---+---------------+-------+
|0  |[1, 2, 3, 4, 5]|[1]    |
|1  |[1, 5, 3, 7, 9]|[0, 2] |
|2  |[9, 7, 6, 8, 0]|[1]    |
+---+---------------+-------+
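
As a quick sanity check of what the query computes, the same pairwise-overlap logic can be sketched in plain Python on the three sample rows. This only illustrates the semantics on a small input and is not a replacement for the Spark job; the helper name find_matches is hypothetical.

```python
def find_matches(rows):
    """Map each id to the ids (as strings) of the other rows sharing at least one hash."""
    # Convert each array to a set so that element positions are ignored.
    sets = {row_id: set(hashes) for row_id, hashes in rows}
    return {
        row_id: [str(other) for other in sets
                 if other != row_id and sets[row_id] & sets[other]]
        for row_id in sets
    }


# The sample data from the question.
rows = [
    (0, ["1", "2", "3", "4", "5"]),
    (1, ["1", "5", "3", "7", "9"]),
    (2, ["9", "7", "6", "8", "0"]),
]

matches = find_matches(rows)
print(matches)
```

Unlike the Spark query, this sketch also returns an entry (with an empty list) for an id that matches nothing, whereas the filter above drops such rows before grouping.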

Schema of the output:

out.printSchema()

root
 |-- id: long (nullable = true)
 |-- hashes: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- Matches: array (nullable = false)
 |    |-- element: string (containsNull = false)
