首页 > 解决方案 > 在 Spark 数据框的列中为每个组添加递增的数字

问题描述

我有一个包含 2 列“Id”和“category”的数据框。对于每个类别,我想标记编码列“Id”,所以预期的结果将是这样的列“Enc_id”

Id   Category    Enc_id
a1       A         0
a2       A         1
b1       B         0 
c1       C         0
c2       C         1
a3       A         2
b2       B         1
b3       B         2 
b4       B         3 
b4       B         3
b3       B         2

在这里,Id 可能不是唯一的,因此可能存在重复的行。我想创建一个窗口partitionBy(category),然后在这个窗口上应用标签编码(StringIndexer),但它没有用。请问有什么提示吗?

标签: scaladataframeapache-sparkapache-spark-sql

解决方案


您可以将window函数与substring函数一起使用并计算rank

val window = Window.partitionBy($"Category", substring($"Id", 1,1)).orderBy("Id")

df.withColumn("Enc_id", rank().over(window) - 1) // -1 to start the rank from 0
  .show(false)

输出:

+---+--------+------+
|Id |Category|Enc_id|
+---+--------+------+
|a1 |A       |0     |
|a2 |A       |1     |
|a3 |A       |2     |
|c1 |C       |0     |
|c2 |C       |1     |
|b1 |B       |0     |
|b2 |B       |1     |
|b3 |B       |2     |
|b4 |B       |3     |
+---+--------+------+

Update1: ​​对于更新后的具有重复 id 的案例

df1.groupBy("Id", "Category")
  .agg(collect_list("Category") as "list_category")
  .withColumn("Enc_id", rank().over(window) - 1)
  .withColumn("Category", explode($"list_category"))
  .drop("list_category")
  .show(false)

输出:

+---+--------+------+
|Id |Category|Enc_id|
+---+--------+------+
|a1 |A       |0     |
|a2 |A       |1     |
|a3 |A       |2     |
|c1 |C       |0     |
|c2 |C       |1     |
|b1 |B       |0     |
|b2 |B       |1     |
|b3 |B       |2     |
|b3 |B       |2     |
|b4 |B       |3     |
|b4 |B       |3     |
+---+--------+------+

推荐阅读