首页 > 解决方案 > 如何在两个不同列表中包含的基于值的数据框中创建新列?

问题描述

我有一个这样的 pyspark 数据框:

+--------------------+--------------------+
|               label|           sentences|
+--------------------+--------------------+
|[things, we, eati...|<p>I am construct...|
|[elephants, nordi...|<p><strong>Edited...|
|[bee, cross-entro...|<p>I have a data ...|
|[milking, markers...|<p>There is an Ma...|
|[elephants, tease...|<p>I have Score d...|
|[references, gene...|<p>I'm looking fo...|
|[machines, exitin...|<p>I applied SVM ...|
+--------------------+--------------------+

和这样的top_ten列表:

['bee', 'references', 'milking', 'expert', 'bombardier', 'borscht', 'distributions', 'wires', 'keyboard', 'correlation']

而且我需要创建一个new_label列,指示列表1.0中是否存在至少一个标签值top_ten(当然,对于每一行)。

虽然逻辑是有道理的,但我在语法方面的经验不足。这个问题肯定有一个简短的答案吗?

我试过了:

temp = train_df.withColumn('label', F.when(lambda x: x.isin(top_ten), 1.0).otherwise(0.0))

和这个:

def matching_top_ten(top_ten, labels):
    for label in labels:
        if label.isin(top_ten):
            return 1.0
        else:
            return 0.0

在最后一次尝试之后,我发现这些函数无法映射到数据帧。所以我想我可以将列转换为 RDD,映射它,然后再.join()返回,但这听起来不必要的乏味。

**更新:**尝试将上述函数作为 UDF 也没有运气......

from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType
matching_udf = udf(matching_top_ten, FloatType())
temp = train_df.select('label', matching_udf(top_ten, 'label').alias('new_labels'))
----
TypeError: Invalid argument, not a string or column: [...top_ten list values...] of type <class 'list'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' function.

我在 SO 上发现了其他类似的问题,但是,它们都不涉及根据另一个列表验证列表的逻辑(充其量是针对列表的单个值)。

标签: python-3.xpysparkapache-spark-sqlpyspark-sqlpyspark-dataframes

解决方案


您不需要使用 audf并且可以避免explode+的费用agg

火花版本 2.4+

您可以使用pyspark.sql.functions.arrays_overlap

import pyspark.sql.functions as F

top_ten_array = F.array(*[F.lit(val) for val in top_ten])

temp = train_df.withColumn(
    'new_label', 
    F.when(F.arrays_overlap('label', top_ten_array), 1.0).otherwise(0.0)
)

或者,您应该能够使用pyspark.sql.functions.array_intersect().

temp = train_df.withColumn(
    'new_label', 
    F.when(
        F.size(F.array_intersect('label', top_ten_array)) > 0, 1.0
    ).otherwise(0.0)
)

label这两个都检查和的交集的大小top_ten是否非零。


对于 Spark 1.5 到 2.3,您可以array_contains在循环中使用top_ten

from operator import or_
from functools import reduce

temp = train_df.withColumn(
    'new_label',
    F.when(
        reduce(or_, [F.array_contains('label', val) for val in top_ten]),
        1.0
    ).otherwise(0.0)
)

您测试以查看是否label包含 中的任何值top_ten,并使用按位或减少结果。True仅当 中的任何值top_ten包含在中时才会返回label


推荐阅读