首页 > 解决方案 > Spark StringIndexer.fit 在大型记录上非常慢

问题描述

我有格式化为以下示例的大型数据记录:

// +---+------+------+
// |cid|itemId|bought|
// +---+------+------+
// |abc|   123|  true|
// |abc|   345|  true|
// |abc|   567|  true|
// |def|   123|  true|
// |def|   345|  true|
// |def|   567|  true|
// |def|   789| false|
// +---+------+------+

cid并且itemId是字符串。

有 965,964,223 条记录。

我正在尝试使用如下方式转换cid为整数:StringIndexer

dataset.repartition(50)
val cidIndexer = new StringIndexer().setInputCol("cid").setOutputCol("cidIndex")
val cidIndexedMatrix = cidIndexer.fit(dataset).transform(dataset)

但是这些代码行非常慢(大约需要 30 分钟)。问题是它是如此之大,以至于在那之后我无法做任何进一步的事情。

我正在使用具有 2 个节点(61 GB 内存)的 R4 2XLarge 集群的亚马逊 EMR 集群。

我可以进一步提高性能吗?任何帮助都感激不尽。

标签: apache-sparkapache-spark-mlapache-spark-dataset

解决方案


如果列的基数很高,这是一种预期的行为。作为训练过程的一部分,StringIndexer收集所有标签,并创建标签-索引映射(使用 Spark 的o.a.s.util.collection.OpenHashMap)。

这个过程在最坏的情况下需要 O(N) 内存,并且是计算和内存密集型的。

如果列的基数很高,并且其内容将用作特征,则最好应用FeatureHasher(Spark 2.3 或更高版本)。

import org.apache.spark.ml.feature.FeatureHasher

val hasher = new FeatureHasher()
  .setInputCols("cid")
  .setOutputCols("cid_hash_vec")
hasher.transform(dataset)

它不保证唯一性并且不可逆,但对于许多应用来说已经足够了,并且不需要拟合过程。

对于不会用作特征的列,您还可以使用hash函数:

import org.apache.spark.sql.functions.hash

dataset.withColumn("cid_hash", hash($"cid"))

推荐阅读