Add new column in Pyspark dataframe based on where condition on other column

Problem description

I have a Pyspark data frame as follows:

+------------+-------------+--------------------+
| package_id | location    | package_scan_code  |
+------------+-------------+--------------------+
| 123        | Denver      | 05                 |
| 123        | LosAngeles  | 03                 |
| 123        | Dallas      | 09                 |
| 123        | Vail        | 02                 |
| 456        | Jacksonville| 05                 |
| 456        | Nashville   | 09                 |
| 456        | Memphis     | 03                 |
+------------+-------------+--------------------+

"package_scan_code" 03 represents the origin of the package.

I want to add a column "origin" to this dataframe such that, for each package (identified by "package_id"), every row's value in the new column is the location of that package's row with "package_scan_code" 03.

In the above case, there are two unique packages 123 and 456, and they have origins as LosAngeles and Memphis respectively (corresponding to package_scan_code 03).

So I want my output to be as follows:

+------------+-------------+--------------------+------------+
| package_id | location    | package_scan_code  | origin     |
+------------+-------------+--------------------+------------+
| 123        | Denver      | 05                 | LosAngeles |
| 123        | LosAngeles  | 03                 | LosAngeles |
| 123        | Dallas      | 09                 | LosAngeles |
| 123        | Vail        | 02                 | LosAngeles |
| 456        | Jacksonville| 05                 | Memphis    |
| 456        | Nashville   | 09                 | Memphis    |
| 456        | Memphis     | 03                 | Memphis    |
+------------+-------------+--------------------+------------+

How can I achieve this in Pyspark? I tried the .withColumn method, but I could not get the condition right.

Tags: python, apache-spark, pyspark, apache-spark-sql, pyspark-sql

Solution


Filter the data frame to the rows where package_scan_code == '03', rename location to origin, and then right-join back to the original data frame on package_id:

(df.filter(df.package_scan_code == '03')
   .selectExpr('package_id', 'location as origin')
   .join(df, ['package_id'], how='right')
   .show())
+----------+----------+------------+-----------------+
|package_id|    origin|    location|package_scan_code|
+----------+----------+------------+-----------------+
|       123|LosAngeles|      Denver|               05|
|       123|LosAngeles|  LosAngeles|               03|
|       123|LosAngeles|      Dallas|               09|
|       123|LosAngeles|        Vail|               02|
|       456|   Memphis|Jacksonville|               05|
|       456|   Memphis|   Nashville|               09|
|       456|   Memphis|     Memphis|               03|
+----------+----------+------------+-----------------+

Note: this assumes each package_id has at most one row with package_scan_code equal to '03'; if a package had several such rows, the join would duplicate every row of that package, and you would need to rethink how origin should be defined.
