首页 > 解决方案 > 如何获取 Spark DataFrame 中每行列表中最大值的索引?[PySpark]

问题描述

我已经完成了 LDA 主题建模并将其存储在lda_model.

转换我的原始输入数据集后,我检索了一个 DataFrame。其中一列是 topicDistribution,其中该行属于 LDA 模型中每个主题的概率。因此,我想获取每行列表中最大值的索引。

df -- | 'list_of_words' | 'index ' | 'topicDistribution' | 
       ['product','...']     0       [0.08,0.2,0.4,0.0001]
          .....             ...         ........

我想转换 df 以便添加一个附加列,它是每行 topicDistribution 列表的 argmax。

df_transformed --  | 'list_of_words' | 'index' | 'topicDistribution' | 'topicID' |
                    ['product','...']     0     [0.08,0.2,0.4,0.0001]      2
                       ......            ....         .....              ....

我该怎么做?

标签: pythonapache-sparkpysparkrdd

解决方案


您可以创建一个用户定义的函数来获取最大值的索引

from pyspark.sql import functions as f
from pyspark.sql.types import IntegerType

max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
df = df.withColumn("topicID", max_index("topicDistribution"))

例子

>>> from pyspark.sql import functions as f
>>> from pyspark.sql.types import IntegerType 
>>> df = spark.createDataFrame([{"topicDistribution": [0.2, 0.3, 0.5]}])
>>> df.show()
+-----------------+
|topicDistribution|
+-----------------+
|  [0.2, 0.3, 0.5]|
+-----------------+

>>> max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
>>> df.withColumn("topicID", max_index("topicDistribution")).show()
+-----------------+-------+
|topicDistribution|topicID|
+-----------------+-------+
|  [0.2, 0.3, 0.5]|      2|
+-----------------+-------+

编辑:

由于您提到其中的列表topicDistribution是 numpy 数组,因此您可以更新max_index udf如下:

max_index = f.udf(lambda x: x.tolist().index(max(x)), IntegerType())

推荐阅读