首页 > 解决方案 > Flink Scala - 比较方法违反了它的一般合同

问题描述

我正在 Flink 中编写一个项目,该项目涉及在批处理数据上流式传输一组查询点并执行完整的顺序扫描以查找最近的邻居。对单个 Float 值进行简单排序操作会引发违反一般合约错误。主要方法定义为:

object StreamingDeCP{
  var points: Vector[Point] = _

  def main(args: Array[String]): Unit = {
    val queryPointsVec: Vector[Point] = ... // Read from file
    val pointsVec: Vector[Point] = ...      // Read from file

    val streamEnv: StreamExecutionEnvironment = 
                   StreamExecutionEnvironment.getExecutionEnvironment
    val queryPoints = streamEnv.fromCollection(queryPointsVec)

    points = pointsVec
    queryPoints.map(new StreamingSequentialScan)

    streamEnv.execute("StreamingDeCP")
  }

  final class StreamingSequentialScan 
                    extends MapFunction[Point, (Point, Vector[Point])] {

    def map(queryPoint: Point): (Point, Vector[Point]) = {
      val nn = points
                .map{ _.eucDist(queryPoint) }
                .sorted

      (queryPoint, nn)
    }
  }
}

Point和伴随对象是:

case class Point(pointID: Long,
                 descriptor: Vector[Float]) extends Serializable {
  var distance: Float = Float.MaxValue

  def eucDist(that: Point): Point = {
    // Simple arithmetic to calculate and set the distance variable
  }
}

object Point{
  implicit def orderByDistance[A <: Point]: Ordering[A] =
    Ordering.by(_.distance)
}

为了查明原因,这里有一些关于我尝试过的事情的注释:

我还注意到,执行相同的代码并不总是可靠地重现错误。我正在Vector[Points]以完全确定的方式阅读,因此导致这种行为的唯一可能原因必须是 Flink 调度程序或排序方法中的一些有状态计算。

关于同一主题的其他帖子似乎涉及自定义比较器中的错过场景,但这应该是对单个浮点值的简单排序操作,所以我不知道是什么导致了这个问题。

标签: javascalaflink-streaming

解决方案


我不熟悉 Flink,但我没有任何理由假设它会以顺序单线程的方式执行每一个令人尴尬的并行任务。 MapFunction

由于您的Pointcontains vars,并且这些s 在 s 的方法中var发生了变异,因此代码必须以“比较方法违反其一般合同”而失败 -每当使用 parallelism 执行时都会出现异常。mapMapFunctionMapFunction!= 1

为了避免函数内部的任何副作用map,您可以修改代码如下:

  • 从 中删除任何vars main,使其points成为不可变的val
  • 删除任何类型的varsPoint
  • 实现方法

    def eucDist(other: Point): Double
    

    这只是计算到另一个点的距离(不改变任何东西)。

  • 使用sortBy

    val nn = points.sortBy(_.eucDist(queryPoint))
    

或者,如果您想避免在排序期间多次重新计算欧几里得距离,请预先计算一次距离,排序,然后丢弃距离:

val nn = points.map(p => (p, p.eucDist(queryPoint))).sortBy(_._2).map(_._1)

推荐阅读