首页 > 解决方案 > 使用带有 pyspark 数据框的 h3 库

问题描述

我有一个看起来像这样的火花数据框:

+-----------+-----------+-------+------------------+----------+--------+--------+--------+--------+
|client_id_x|client_id_y|   dist|              time|      date|   lat_y|   lng_y|   lat_x|   lng_x|
+-----------+-----------+-------+------------------+----------+--------+--------+--------+--------+
| 0700014578| 0700001710|13125.7|21.561666666666667|2021-06-07|-23.6753|-46.6788|-23.5933|-46.6382|
| 0700014578| 0700001760| 8447.8|13.103333333333333|2021-06-07|-23.6346|-46.6057|-23.5933|-46.6382|
| 0700014578| 0700002137| 9681.1|16.173333333333332|2021-06-07|-23.6309|-46.7059|-23.5933|-46.6382|
+-----------+-----------+-------+------------------+----------+--------+--------+--------+--------+

我想做的是基于 H3 地理空间索引系统获取 lat,lng 唯一标识符。为此,我尝试使用以下代码:

def get_geo_id(df: pd.DataFrame) -> pd.Series:
    return df.apply(lambda x: h3.geo_to_h3(x[lat_name], x[lng_name], resolution = 13))
    
get_geo_udf = pandas_udf(get_geo_id, returnType=StringType())

# calling function
new_df.withColumn("id_h3_x", get_geo_udf(new_df.select(["lat_x", "lng_x"])))    

但是,我收到以下错误:

TypeError: Invalid argument, not a string or column: DataFrame[lat_x: double, lng_x: double] of type <class 'pyspark.sql.dataframe.DataFrame'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' function.

我也试过这个:

def get_geo_id(lat_name: pd.Series, lng_name: pd.Series) -> pd.Series:
    return h3.geo_to_h3(lat_name, lng_name, resolution = 13)
    
get_geo_udf = pandas_udf(get_geo_id, returnType = StringType())

new_df.withColumn("id_h3_x", get_geo_udf(new_df["lat_x"], new_df["lng_x"])).show() 

但它显示了这个错误:

TypeError: cannot convert the series to <class 'float'>

我是 spark 的新手,所以我不太确定我遇到的错误。我将衷心感谢您的帮助。

标签: pythondataframeapache-sparkpysparkh3

解决方案


我设法解决了这个问题。我不得不使用以下功能

@pandas_udf("client_id_y string, client_id_x string, dist double, time double, date string, lat_x double, lng_x double, lat_y double, lng_y double, geoid_x string, geoid_y string", PandasUDFType.GROUPED_MAP)
def get_geo_id(df):
    df["geoid_x"] = df.apply(lambda x: h3.geo_to_h3(x.lat_x, x.lng_x, resolution = 13), axis = 1)
    df["geoid_y"] = df.apply(lambda x: h3.geo_to_h3(x.lat_y, x.lng_y, resolution = 13), axis = 1)
    return df

# call the function
h3_dff = new_df.groupby("client_id_x").apply(get_geo_id)
h3_dff.show()

结果数据框是:

+-----------+-----------+-------+------------------+----------+--------+--------+--------+--------+---------------+---------------+
|client_id_y|client_id_x|   dist|              time|      date|   lat_x|   lng_x|   lat_y|   lng_y|        geoid_x|        geoid_y|
+-----------+-----------+-------+------------------+----------+--------+--------+--------+--------+---------------+---------------+
| 0700001710| 0700014578|13125.7|21.561666666666667|2021-06-07|-23.5933|-46.6382|-23.6753|-46.6788|8da8100e225e4ff|8da81000890577f|
| 0700001760| 0700014578| 8447.8|13.103333333333333|2021-06-07|-23.5933|-46.6382|-23.6346|-46.6057|8da8100e225e4ff|8da81001b0b353f|
| 0700002137| 0700014578| 9681.1|16.173333333333332|2021-06-07|-23.5933|-46.6382|-23.6309|-46.7059|8da8100e225e4ff|8da810056a5673f|

这正是我想要的。


推荐阅读