首页 > 解决方案 > 使用pyspark在collect_set之后按值选择行

问题描述

使用

from pyspark.sql import functions as f

和方法f.aggf.collect_set我在 dataFrame 中创建了一个列 colSet 如下所示:

+-------+--------+
| index | colSet |
+-------+--------+
|      1|[11, 13]|
|      2|  [3, 6]|
|      3|  [3, 7]|
|      4|  [2, 7]|
|      5|  [2, 6]|
+-------+--------+

现在,如何使用 python/ 和 pyspark 仅选择那些行,例如,3 是 colSet 条目中数组的一个元素(通常,其中可能不止两个条目!)?

我试过使用这样的 udf 函数:

isInSet = f.udf( lambda vcol, val: val in vcol, BooleanType())

通过调用

dataFrame.where(isInSet(f.col('colSet'), 3))

我还尝试从调用者中删除 f.col 并在 isInSet 的定义中使用它,但都没有奏效,我遇到了一个异常:

AnalysisException: cannot resolve '3' given input columns: [index, colSet]

在给定具有 collect_set 结果的行的情况下,对于如何选择具有某个条目(甚至更好的子集!!!)的行的任何帮助表示赞赏。

标签: selectpysparkrow

解决方案


您的原始 UDF 很好,但要使用它,您需要将值 3 作为文字传递:

dataFrame.where(isInSet(f.col('colSet'), f.lit(3)))

但正如 jxc 在评论中指出的那样,使用array_contains可能是一个更好的选择:

dataFrame.where(f.array_contains(f.col('colSet'), 3))

我没有做过任何基准测试,但通常在 PySpark 中使用 UDF 比使用内置函数要慢,因为 JVM 和 Python 解释器之间的来回通信。


推荐阅读