首页 > 解决方案 > 如何在 DataFrame 中的列中找到与其他 DataFrame 中的另一列最接近的值?

问题描述

我有两个从两个 csv 文件中读取的数据帧。

机场DF 150kb 7000 条记录

iata_code   latitude    longitude
AAA -17.352606  -145.509956
AAB -26.69317   141.0478
AAC 31.07333    33.83583

userDF ~75MB ~100 万条记录

uuid    geoip_latitude  geoip_longitude
DDEFEBEA-98ED-49EB-A4E7-9D7BFDB7AA0B    -37.8333015441895   145.050003051758
DAEF2221-14BE-467B-894A-F101CDCC38E4    52.5167007446289    4.66669988632202
31971B3E-2F80-4F8D-86BA-1F2077DF36A2    35.685001373291 139.751403808594

我想根据地理距离找到离用户最近的机场。

输出应该有两列UUID 和对应的 iata_code

我有计算地理距离的haversine效用函数

def distance(
      startLon: Double,
      startLat: Double,
      endLon: Double,
      endLat: Double,
      R: Double
  ): Double = {
    val dLat = math.toRadians(endLat - startLat)
    val dLon = math.toRadians(endLon - startLon)
    val lat1 = math.toRadians(startLat)
    val lat2 = math.toRadians(endLat)

    val a =
      math.sin(dLat / 2) * math.sin(dLat / 2) +
        math.sin(dLon / 2) * math.sin(dLon / 2) * math.cos(lat1) * math.cos(lat2)
    val c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

    R * c
  }

编辑:

userDF
 |-- uuid: string (nullable = true)
 |-- geoip_latitude: double (nullable = true)
 |-- geoip_longitude: double (nullable = true)

airportDF
 |-- iata_code: string (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)


transformations(spark, userDF, airportDF).show()

def transformations(spark: SparkSession, userDF: DataFrame, airportDF: DataFrame) = {
    val airports = broadcastDF(spark, airportDF)
    userDF.transform(findNearestAirport(spark, airports.value))
  }

  def broadcastDF(spark: SparkSession, df: DataFrame) = {
    spark.sparkContext.broadcast(df.collect())
  }

  def findNearestAirport(spark: SparkSession, airports: Array[Row])(
    userDF: DataFrame
  ): DataFrame = {
    import spark.implicits._

    var distance = Double.MaxValue
    var minDistance = Double.MaxValue
    var nearestAirportID = ""

    userDF.flatMap { user =>
      airports.foreach { airport =>
        distance = Haversine.distance(
          user.getAs[Double]("geoip_longitude"),
          user.getAs[Double]("geoip_latitude"),
          airport.getAs[Double]("longitude"),
          airport.getAs[Double]("latitude")
        )
        if (minDistance > distance) {
          minDistance = distance
          nearestAirportID = airport.getAs[String]("iata_code")
        }
      }
      println(s"User ${user.getAs[String]("uuid")} is closest to airport $nearestAirportID")
      Seq((user.getAs[String]("uuid"), nearestAirportID))
    }.toDF("uuid", "iata_code")
  }

所以我完成了代码,但有几个问题。

  1. 我使用了 DF.transform 函数而不是 UDF。是更好还是一样?
  2. 互联网上的大多数/所有广播示例都具有类似地图的结构/json/case 类。我只是按原样用DF广播。一个比另一个有什么优势/劣势。
  3. 有什么办法可以改进代码吗?
  4. 这是一个足够好的可扩展解决方案吗?我选择自己使用 spark,就好像数据在流式传输一样,它也可以轻松处理。考虑到每秒可能有成百上千的事件,如果不使用像 Spark 这样的流/批处理引擎,(在 Scala 中)还有什么其他可扩展的选项?

标签: scalaapache-spark

解决方案


  1. 首先,使用 uuid、lat/long 列合并成一个元组struct

    import org.apache.spark.sql.functions.struct
    airportDF.withColumn("uuid_lat_long_struct", struct(airportDF("uuid"),airportDF("geoip_lat"), airportDF("geoip_long"))
    
  2. 用于collect_set将 airportDF 中的每一行折叠到一个数组/列表并创建一个新的数据框airportFlattenedDF。例如

    airportDF.groupBy('uuid').agg(collect_list('uuid_lat_long_struct'))
    
  3. 使用这个 airportDFcollect()在驱动程序端进入一个元组数组,然后broadcast到所有执行程序节点。

  4. 编写一个 lambda 或 UDF 来获取 userDF 的每条记录,并将其与数组中的每个元素进行比较,然后选择距离最小的元素并从airportFlattenedDF

  5. 您可以做的另一个优化是使用flat-earth公式来减少三角函数的数量,因为您只关心相对距离。


推荐阅读