首页 > 解决方案 > Spark - 更改数据集中属于长尾的记录的值

问题描述

我正在尝试解决机器学习问题中的数据清理步骤,我应该将长尾中的所有元素分组到一个名为“其他”的通用类别中。例如,我有一个这样的数据框:

val df = sc.parallelize(Seq(
(1, "ABC"),
(2, "ABC"),
(3, "123"),
(4, "FPK"),
(5, "FPK"),
(6, "ABC"),
(7, "ABC"),
(8, "980"),
(9, "abc"),
(10, "FPK")
)).toDF("n", "s")

我想保留类别"ABC""FPK"因为它们出现了好几次,但我不想有一个不同的类别:123,980,abc因为它们只出现一次。所以我想要的是:

+---+------+
|  n|     s|
+---+------+
|  1|   ABC|
|  2|   ABC|
|  3|Others|
|  4|   FPK|
|  5|   FPK|
|  6|   ABC|
|  7|   ABC|
|  8|Others|
|  9|Others|
| 10|   FPK|
+---+------+

为了实现这一点,我尝试的是:

val newDF = df.withColumn("s",when($"s".isin("123","980","abc"),"Others").otherwise('s)

这工作正常。

但我想以编程方式决定哪些类别属于长尾,在我的情况下,在 originall 数据框中只出现一次。所以我写了这个来创建一个数据框,其中包含那些只出现一次的类别:

val longTail = df.groupBy("s").agg(count("*").alias("cnt")).orderBy($"cnt".desc).filter($"cnt"<2)

+---+---+
|  s|cnt|
+---+---+
|980|  1|
|abc|  1|
|123|  1|
+---+---+

现在我试图将这个 longTail 数据集中的“s”列的值转换为一个列表,以便用我之前硬编码的那个来交换它。所以我尝试了:

 val ar = longTail.select("s").collect().map(_(0)).toList

ar: List[Any] = List(123, 980, abc)

但是当我尝试添加 ar

val newDF = df.withColumn("s",when($"s".isin(ar),"Others").otherwise('s))

我收到以下错误:

java.lang.RuntimeException:不支持的文字类型类 scala.collection.immutable.$colon$colon List(123, 980, abc)

我错过了什么?

标签: scalaapache-sparkmachine-learning

解决方案


这是正确的语法:

scala> df.withColumn("s", when($"s".isin(ar : _*), "Others").otherwise('s)).show
+---+------+
|  n|     s|
+---+------+
|  1|   ABC|
|  2|   ABC|
|  3|Others|
|  4|   FPK|
|  5|   FPK|
|  6|   ABC|
|  7|   ABC|
|  8|Others|
|  9|Others|
| 10|   FPK|
+---+------+

这称为重复参数。参考这里


推荐阅读