Filtering columns in Spark with missing values above a threshold

Problem Description

I want to find every column where more than 90% of the values are missing, so that I can drop those columns from my analysis.

I tried the following code, but it takes a very long time to run:

from pyspark.sql.functions import isnull, when, count, col
total_rows = df.count()
features_missing_above_90 = []
    
for feature in df.columns:
    
    feature_nulls = df.filter(df[feature].isNull()).count()
    result = (feature_nulls/total_rows)
    
    if(result>0.9):
        features_missing_above_90.append(df[feature].str)
    
print(features_missing_above_90)

Can anyone help me?

Thanks!

Tags: python, apache-spark, pyspark

Solution


Change df.count() to a float. Without this your if condition will never fire, because integer/integer gives you only an integer in Python 2 (which this example appears to use). Also append the column name feature itself instead of df[feature].str, which is not a valid Column attribute.

>>> from pyspark.sql import functions as F
>>> total_rows = float(df.count())
>>> features_missing_above_90 = []
>>> for feature in df.columns:
...     feature_nulls = df.filter(F.col(feature).isNull()).count()
...     result = feature_nulls/total_rows
...     if result > 0.7:  # 0.7 so the small sample data below triggers; use 0.9 for your data
...             features_missing_above_90.append(feature)
...
>>> print(features_missing_above_90)
['b', 'e']
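Note that this loop still launches one Spark job per column, which is why the original code felt slow. As a minimal sketch (assuming the same df and the 90% threshold from the question), you can count the nulls of every column in a single aggregation, using the when/count functions the question already imports:

from pyspark.sql import functions as F

total_rows = df.count()

# One aggregation row mapping each column name to its null count.
null_counts = df.select(
    [F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in df.columns]
).collect()[0].asDict()

# Keep the names of columns whose missing fraction exceeds 90%.
features_missing_above_90 = [
    c for c, n in null_counts.items() if n / float(total_rows) > 0.9
]
print(features_missing_above_90)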

You can also try the code below, which checks all columns in a single aggregation.

>>> df.show()
+----+----+----+----+----+----+
|   a|   b|   c|   d|   e|   f|
+----+----+----+----+----+----+
|   1|null|   2|   3|null|   4|
|   2|null|null|   5|null|   6|
|   4|   2|   1|   4|   5|   6|
|   5|null|null|null|null|   4|
|null|null|null|null|null|null|
+----+----+----+----+----+----+
>>> threshold = 0.7

Applying the logic: F.count(F.col(c)) counts only the non-null values in column c, so subtracting it from the total row count gives the number of missing values; F.when keeps the column name only when its missing fraction exceeds the threshold.

>>> columns = map(lambda c: F.when(((F.count(F.col("*")) - F.count(F.col(c)))/F.count(F.col("*"))) > F.lit(threshold), F.lit(c)), df.columns)
>>> df \
... .select(F.array(*columns).alias("columns")) \
... .select(F.explode(F.col("columns")).alias("columns")) \
... .select(F.collect_set(F.col("columns")).alias("missing")) \
... .collect()[0]

Row(missing=[u'b', u'e'])
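Once you have the list of column names, dropping them from the analysis is a one-liner. A small usage sketch, assuming the result from either approach above and an illustrative df_clean name:

missing = ['b', 'e']            # result from either approach above
df_clean = df.drop(*missing)    # DataFrame.drop accepts several column names
df_clean.show()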
