首页 > 解决方案 > Pyspark NLP - CountVectorizer Max DF 或 TF。如何从数据集中过滤常见事件

问题描述

我正在使用CountVectorizer为 ML 准备数据集。我想过滤掉稀有词CountVectorizer,为此我使用了 minDF 或 minTF 的参数。我还想删除在我的数据集中“经常”出现的项目。我没有看到可以设置的 maxTF 或 maxDF 参数。有没有好的方法来做到这一点?

df = spark.createDataFrame(
[(0, ["a", "b", "c","b"]), (1, ["a", "b", "b", "c", "a"])],
["label", "raw"])

因此,在这种情况下,如果我想删除出现“4”次或 40% 时间的参数,以及出现 2 次或更少的参数。这将删除“b”和“c”。

目前我CountVectorizer(minDf=3......)为下限req运行。如何过滤掉出现频率高于我想要建模的项目。

标签: pythonapache-sparkpysparknlpapache-spark-ml

解决方案


我想您要求提供 CountVectorizer 参数,但到目前为止似乎还没有参数。这不是一种简单或实用的简单方法,但它确实有效。我希望这可以帮助你:

from pyspark.sql.types import *
from pyspark.sql import functions as F

df = spark.createDataFrame(
[(0, ["a", "b", "c","b"]), (1, ["a", "b", "b", "c", "a"])],
["label", "raw"])

counts_df = df \
    .select(F.explode('raw').alias('testCol')) \
    .groupby('testCol') \
    .agg(F.count('testCol').alias('count')).persist() # this will be used multiple times

total = counts_df \
    .agg(F.sum('count').alias('total')) \
    .rdd.take(1)[0]['total']
min_times = 3
max_times = total * 0.4
filtered_elements = counts_df \
    .filter((min_times>F.col('count')) | (F.col('count')>max_times)) \
    .select('testCol') \
    .rdd.map(lambda row: row['testCol']) \
    .collect()

def removeElements(arr):
    return list(set(arr) - set(filtered_elements))

remove_udf = F.udf(removeElements, ArrayType(StringType()))
filtered_df = df \
    .withColumn('raw', remove_udf('raw'))

结果:

filtered_df.show()
+-----+---+
|label|raw|
+-----+---+
|    0|[a]|
|    1|[a]|
+-----+---+

推荐阅读