首页 > 解决方案 > 如何在 Spark Scala 中的 Double 上使用 sqrt

问题描述

我正在尝试在 Spark (Scala 2.11) 上手动计算均方根误差 (RMSE)

瑟

如上面的截图,我计算每一行的平方误差(SE)

val predicted_with_sqr_err = predicted.withColumn("se", pow(($"medianHouseValue" - $"prediction"), lit(2)))

然后我计算均方误差(MSE)

val sum_se = predicted_with_sqr_err.agg(sum("se")).first.get(0)
val sum_se_double = sum_se.toString.toDouble
val mean_sqr_err = (1.0/predicted_with_sqr_err.count)*sum_se_double 

它工作得很好。但是当我尝试平方根来计算均方根误差(RMSE)时。

val root_mean_sqr_err = sqrt(mean_sqr_err)

它给出错误:

<console>:83: error: overloaded method value sqrt with alternatives:
  (colName: String)org.apache.spark.sql.Column <and>
  (e: org.apache.spark.sql.Column)org.apache.spark.sql.Column
 cannot be applied to (Double)
       val root_mean_sqr_err = sqrt(mean_sqr_err)

sqrt 错误

我应该如何解决?

标签: scalaapache-sparkapache-zeppelin

解决方案


问题是您使用sqrt的是在Spark SQL. 此函数应仅用作 Spark SQL DSL 的一部分(在选择、聚合等中)。它采用ColumnorString作为参数,但您试图传递Double. 而是使用包sqrt中定义的函数scala.math

val root_mean_sqr_err = math.sqrt(mean_sqr_err)

推荐阅读