首页 > 解决方案 > 迭代列Scala的元素

问题描述

我有一个由两个双精度数组组成的数据框。我想创建一个新列,它是将欧几里德距离函数应用于前两列的结果,即如果我有:

 A      B 
(1,2)  (1,3)
(2,3)  (3,4)

创造:

 A      B     C
(1,2)  (1,3)  1
(2,3)  (3,4)  1.4

我的数据架构是:

df.schema.foreach(println)
StructField(col1,ArrayType(DoubleType,false),false)
StructField(col2,ArrayType(DoubleType,false),true)

每当我调用这个距离函数时:

def distance(xs: Array[Double], ys: Array[Double]) = {
  sqrt((xs zip ys).map { case (x,y) => pow(y - x, 2) }.sum)
}

我收到一个类型错误:

df.withColumn("distances" , distance($"col1",$"col2"))
<console>:68: error: type mismatch;
 found   : org.apache.spark.sql.ColumnName
 required: Array[Double]
       ids_with_predictions_centroids3.withColumn("distances" , distance($"col1",$"col2"))

我知道我必须遍历每一列的元素,但我无法在任何地方找到如何执行此操作的解释。我对 Scala 编程非常陌生。

标签: scalaapache-sparkapache-spark-sql

解决方案


要在数据框上使用自定义函数,您需要将其定义为UDF. 例如,可以这样做:

val distance = udf((xs: WrappedArray[Double], ys: WrappedArray[Double]) => {
  math.sqrt((xs zip ys).map { case (x,y) => math.pow(y - x, 2) }.sum)
})

df.withColumn("C", distance($"A", $"B")).show()

请注意,此处需要使用WrappedArray(或)。Seq

结果数据框:

+----------+----------+------------------+
|         A|         B|                 C|
+----------+----------+------------------+
|[1.0, 2.0]|[1.0, 3.0]|               1.0|
|[2.0, 3.0]|[3.0, 4.0]|1.4142135623730951|
+----------+----------+------------------+

推荐阅读