首页 > 解决方案 > 从列有条件地计数

问题描述

我有一些数据存储在df1

| id1 | id2      | product |
|-----|----------|---------|
| 1   | abc-2323 | Upload  |
| 1   | 234234   | Upload  |
| 1   | 43322    | Upload  |
| 2   | abc-449  | Upload  |
| 3   | abc-495  | Upload  |
| 3   | 432      | Upload  |
| 3   | 6543     | Upload  |

id1and product,我想根据不同id2的 s 开始abc-与否来计算它们。我试过这样的条件聚合:

agg_data = (
    df1
    .groupby('id1', 'product')
    .agg(
       sf.when(~ sf.col('id2').like('abc-%'), sf.countDistinct('id2')).alias('id2_count_without_abc'),
       sf.when(sf.col('id2').like('abc-%'), sf.countDistinct('id2')).alias('is2_count_with_abc')
    )
)

然而,这个错误id2既不存在于 group by 中,也不是聚合函数。我不确定为什么我不能有条件地做这个,因为sf.countDistinct('id2')它本身就可以工作。例如,这没问题:

agg_data = (
    df1
    .groupby('id1', 'product')
    .agg(
       sf.countDistinct('id2').alias('id2_count_without_abc'),
       sf.countDistinct('id2').alias('is2_count_with_abc')
    )
)

标签: apache-sparkpysparkapache-spark-sql

解决方案


当你这样做时groupBy,每组会得到很多行。由于您可以为每组获得一个值countDistinct(),因此您不会出错。但是对于 when 子句,您一次可以传递一个值,但 group 将返回多行。

要消除此错误,您可以使用first()or选择一个值last()(这将返回相应组的第一个/最后一个值)并将其传递给when()函数。

但是您需要 group 的所有值来基于一个条件进行计数计算。下面的代码应该可以解决问题。

df1.groupby('id1', 'product') \
.agg(sf.collect_list('id2').alias('id2')) \
.withColumn('id2_count_with_abc', sf.size(sf.expr("filter(id2, i-> i like 'abc-%')"))) \
.withColumn('id2_count_without_abc', sf.abs(sf.size(sf.col('id2')) - sf.col('id2_count_with_abc'))) \
.drop('id2').show()


+---+-------+------------------+---------------------+
|id1|product|id2_count_with_abc|id2_count_without_abc|
+---+-------+------------------+---------------------+
|  2| Upload|                 1|                    0|
|  1| Upload|                 1|                    2|
|  3| Upload|                 1|                    2|
+---+-------+------------------+---------------------+

推荐阅读