首页 > 解决方案 > 如何在 PySpark 中找出组中列的唯一元素数?

问题描述

我有一个 PySpark 数据框-

df1 = spark.createDataFrame([
    ("u1", 1),
    ("u1", 2),
    ("u2", 1),
    ("u2", 1),
    ("u2", 1),
    ("u3", 3),
    ],
    ['user_id', 'var1'])

print(df1.printSchema())
df1.show(truncate=False)

输出-

root
 |-- user_id: string (nullable = true)
 |-- var1: long (nullable = true)

None
+-------+----+
|user_id|var1|
+-------+----+
|u1     |1   |
|u1     |2   |
|u2     |1   |
|u2     |1   |
|u2     |1   |
|u3     |3   |
+-------+----+

现在我想对所有唯一用户进行分组,并在新列中显示他们的唯一 var 数量。所需的输出看起来像 -

+-------+---------------+
|user_id|num_unique_var1|
+-------+---------------+
|u1     |2              |
|u2     |1              |
|u3     |1              |
+-------+---------------+

我可以使用 collect_set 并制作一个 udf 来查找集合的长度。但我认为必须有更好的方法来做到这一点。如何在一行代码中实现这一目标?

标签: apache-sparkpysparkapache-spark-sql

解决方案


df1.groupBy('user_id').agg(F.countDistinct('var1').alias('num')).show()

countDistinct正是我所需要的。

输出-

+-------+---+
|user_id|num|
+-------+---+
|     u3|  1|
|     u2|  1|
|     u1|  2|
+-------+---+

推荐阅读