首页 > 解决方案 > PySpark:如何使用`StringIndexer`对字符串数组列进行标签编码

问题描述

众所周知,我们可以在字符串列上做LabelEncoder()StringIndexer但是如果要LabelEncoder()在字符串数组列上做,实现起来并不容易。

# input
df.show()

+--------------------------------------+
|                                  tags|
+--------------------------------------+
|        [industry, display, Merchants]|
|    [smart, swallow, game, Experience]|
|             [social, picture, social]|
|        [default, game, us, adventure]|
| [financial management, loan, product]|
|       [system, profile, optimization]|

...
# After do LabelEncoder() on `tags` column 
...

+--------------------------------------+
|                                  tags|
+--------------------------------------+
|                             [0, 1, 2]|
|                          [3, 4, 4, 5]|
|                             [6, 7, 6]|
|                         [8, 4, 9, 10]|
|                          [11, 12, 13]|
|                          [14, 15, 16]|

标签: dataframeapache-sparkpysparkapache-spark-sql

解决方案


Python 版本将非常相似:

// add unique id to each row
val df2 = df.withColumn("id", monotonically_increasing_id).select('id, explode('tags).as("tag"))

val indexer = new StringIndexer()
  .setInputCol("tag")
  .setOutputCol("tagIndex")

val indexed = indexer.fit(df2).transform(df2)

// in the final step you should convert tags back to array of tags
val dfFinal = indexed.groupBy('id).agg(collect_list('tagIndex))

推荐阅读