How to find whether any element of an array is within a range in PySpark

Problem description

I have a column of arrays in a dataframe, and I want to know whether any element of each array falls within a certain range. Example:
Input:

+------------------------------------------------------------------------------------------+
|dateTimeValue                                                                             |
+------------------------------------------------------------------------------------------+
|[2019-11-11T20:08:47.453+0000, 2020-10-15T20:08:47.453+0000, 2021-09-19T20:08:47.453+0000]|
|[2017-11-05T20:08:47.453+0000, 2020-05-05T20:08:47.453+0000, 2021-11-11T20:08:47.453+0000]|
+------------------------------------------------------------------------------------------+

The date range of interest is August 8, 2018 to December 8, 2019.
Output:

+------------------------------------------------------------------------------------------+------------+
|dateTimeValue                                                                             |includedFlag|
+------------------------------------------------------------------------------------------+------------+
|[2019-11-11T20:08:47.453+0000, 2020-10-15T20:08:47.453+0000, 2021-09-19T20:08:47.453+0000]|True        |
|[2017-11-05T20:08:47.453+0000, 2020-05-05T20:08:47.453+0000, 2021-11-11T20:08:47.453+0000]|False       |
+------------------------------------------------------------------------------------------+------------+

The schema of my dataframe is:

root
 |-- dateTimeValue: array (nullable = true)
 |    |-- element: timestamp (containsNull = true)

The input can be generated as follows:

import datetime
df = spark.createDataFrame(
    [([datetime.datetime(2019,11,11,20,8,47), datetime.datetime(2020,10,15,20,8,47), datetime.datetime(2021,9,19,20,8,47)],),
     ([datetime.datetime(2017,11,5,20,8,47), datetime.datetime(2020,5,5,20,8,47), datetime.datetime(2021,11,11,20,8,47)],)], ['dateTimeValue'])
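
For reference, calling df.printSchema() on this dataframe should print the array-of-timestamp schema shown above:

df.printSchema()
# root
#  |-- dateTimeValue: array (nullable = true)
#  |    |-- element: timestamp (containsNull = true)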

Thanks.

Tags: python, arrays, apache-spark, pyspark

Solution


With explode you can create a new row for each element of the array and compare each element with the boundaries of the range. Grouping back on the original array column and taking the max of the per-element flag then tells you whether any element fell inside the range.

import datetime
import pyspark.sql.functions as F

df = spark.createDataFrame(
    [([datetime.datetime(2019,11,11,20,8,47), datetime.datetime(2020,10,15,20,8,47), datetime.datetime(2021,9,19,20,8,47)],),
     ([datetime.datetime(2017,11,5,20,8,47), datetime.datetime(2020,5,5,20,8,47), datetime.datetime(2021,11,11,20,8,47)],)], ['dateTimeValue'])
df.show(truncate=False)

# One row per array element
df = df.withColumn('ex', F.explode('dateTimeValue'))
# Flag each element that falls inside the range (2018-08-08, 2019-12-08)
df = df.withColumn('includedFlag', F.when((F.col('ex') > datetime.datetime(2018,8,8)) & (F.col('ex') < datetime.datetime(2019,12,8)), 1).otherwise(0))
# Collapse back to one row per original array: the max of the flags is 1 if any element matched
df.groupby('dateTimeValue').agg(F.max('includedFlag').alias('includedFlag')).show(truncate=False)

Output:

+---------------------------------------------------------------+
|dateTimeValue                                                  |
+---------------------------------------------------------------+
|[2019-11-11 20:08:47, 2020-10-15 20:08:47, 2021-09-19 20:08:47]|
|[2017-11-05 20:08:47, 2020-05-05 20:08:47, 2021-11-11 20:08:47]|
+---------------------------------------------------------------+

+---------------------------------------------------------------+------------+
|dateTimeValue                                                  |includedFlag|
+---------------------------------------------------------------+------------+
|[2017-11-05 20:08:47, 2020-05-05 20:08:47, 2021-11-11 20:08:47]|0           |
|[2019-11-11 20:08:47, 2020-10-15 20:08:47, 2021-09-19 20:08:47]|1           |
+---------------------------------------------------------------+------------+
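
As a side note (not part of the original answer): on Spark 3.1 and later the same check can be written without the explode/groupby round trip by using the higher-order function F.exists, which also produces the boolean flag in the format the question asks for. A minimal sketch, assuming df is the original dataframe from the question (before the explode):

from pyspark.sql import functions as F
import datetime

low = datetime.datetime(2018, 8, 8)
high = datetime.datetime(2019, 12, 8)

# exists() is true if the predicate holds for at least one array element
df.withColumn(
    'includedFlag',
    F.exists('dateTimeValue', lambda x: (x > F.lit(low)) & (x < F.lit(high)))
).show(truncate=False)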
