Spark: decimalType not found

Problem description

I am trying to use DecimalType(18,2). Here is my code:

import org.apache.spark.sql.types.DataTypes._

object ETL {
  //created a DecimalType
  val decimalType = DataTypes.createDecimalType(18,2)

  case class SKU(price_usd: decimalType)
}

I get the error decimalType not found. How can I fix this? Thanks.

By the way, I have tried BigDecimal, but it comes out as (38, 18) and I need (18, 2). In my Spark job I use SQL to select some columns that are (18, 2), and I want to write UDFs to process them. I don't know how to declare the decimal data type inside a UDF.

Tags: scala, apache-spark, decimal, user-defined-functions, bigdecimal

Solution


In your code, decimalType is not a Scala type identifier; it is a value of the class DecimalType. Therefore you cannot use it in a position where the compiler expects a type identifier.
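For illustration, a minimal sketch of code that compiles: keep decimalType as a value for places that expect a DataType, such as an explicit schema, and use an ordinary Scala type like BigDecimal in the case class (the skuSchema value is an added name for illustration, not part of the original question):

import org.apache.spark.sql.types.{DataTypes, DecimalType, StructField, StructType}

object ETL {
  // A value of class DecimalType; usable wherever Spark expects a DataType value,
  // but not as a type annotation in a case class.
  val decimalType: DecimalType = DataTypes.createDecimalType(18, 2)

  // In the case class, use an ordinary Scala type.
  case class SKU(price_usd: BigDecimal)

  // The DecimalType value can instead be used in an explicit schema.
  val skuSchema = StructType(Seq(StructField("price_usd", decimalType)))
}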

To write a UDF, you can simply use java.math.BigDecimal as the parameter type; there is no need to specify precision and scale there. However, if you really need those values for the computation inside the UDF, you can try specifying them with a MathContext.
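As a rough illustration of the MathContext idea (the helper name and rounding mode below are assumptions, not taken from the original answer):

import java.math.{BigDecimal => JBigDecimal, MathContext, RoundingMode}

// Hypothetical helper for use inside a UDF body:
// MathContext bounds the precision, setScale fixes the number of fraction digits.
def roundTo18_2(bd: JBigDecimal): JBigDecimal =
  if (bd == null) null
  else bd.round(new MathContext(18)).setScale(2, RoundingMode.HALF_UP)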

package HelloSpec.parser

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.types.{DecimalType, StructField, StructType}
import org.scalatest.FlatSpec


case class SKU(price_usd: BigDecimal)

object Fields {
  val PRICE_USD = "price_usd"
}

class TestSo extends FlatSpec with DataFrameSuiteBase with SharedSparkContext {

  import Fields._

  it should "not fail" in {
    import spark.implicits._
    val df = Seq(
      SKU(BigDecimal("1.12")),
      SKU(BigDecimal("1234567890123456.12")),
      SKU(BigDecimal("1234567890123456.123")),
      SKU(BigDecimal("12345678901234567.12"))
    ).toDF

    df.printSchema()
    df.show(truncate = false)

    assert(
      df.schema ==
        StructType(Seq(StructField(name = PRICE_USD, dataType = DecimalType(38, 18))))
    )

    val castedTo18_2 = df.withColumn(PRICE_USD, df(PRICE_USD).cast(DecimalType(18, 2)))
    castedTo18_2.printSchema()
    castedTo18_2.show(truncate = false)
    assert(
      castedTo18_2.schema ==
        StructType(Seq(StructField(name = PRICE_USD, dataType = DecimalType(18, 2))))
    )
    assert {
      castedTo18_2.as[Option[BigDecimal]].collect.toSeq.sorted == Seq(
        // this was 12345678901234567.12 before the cast,
        // but the number with 17 digits before the decimal point exceeded the 18-2=16 allowed digits
        None,
        Some(BigDecimal("1.12")),
        Some(BigDecimal("1234567890123456.12")),
        // note, that 1234567890123456.123 was rounded to 1234567890123456.12
        Some(BigDecimal("1234567890123456.12"))
      )
    }

    import org.apache.spark.sql.functions.{udf, col}
    val processBigDecimal = udf(
      // The argument type has to be java.math.BigDecimal, not scala.math.BigDecimal, which is imported by default
      (bd: java.math.BigDecimal) => {
        if (bd == null) {
          null
        } else {
          s"${bd.getClass} with precision ${bd.precision}, scale ${bd.scale} and value $bd"
        }
      }
    )

    val withUdfApplied = castedTo18_2.
      withColumn("udf_result", processBigDecimal(col(PRICE_USD)))

    withUdfApplied.printSchema()
    withUdfApplied.show(truncate = false)

    assert(
      withUdfApplied.as[(Option[BigDecimal], String)].collect.toSeq.sorted == Seq(
        None -> null,
        Some(BigDecimal("1.12")) -> "class java.math.BigDecimal with precision 19, scale 18 and value 1.120000000000000000",
        Some(BigDecimal("1234567890123456.12")) -> "class java.math.BigDecimal with precision 34, scale 18 and value 1234567890123456.120000000000000000",
        Some(BigDecimal("1234567890123456.12")) -> "class java.math.BigDecimal with precision 34, scale 18 and value 1234567890123456.120000000000000000"
      )
    )
  }
}
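If the UDF itself should produce a (18, 2) column, note that Spark types a Scala UDF returning BigDecimal as DecimalType(38, 18), so one option is to cast the UDF's result column afterwards. A minimal sketch, assuming a DataFrame df with a price_usd column like the one in the test above (the doublePrice UDF is hypothetical):

import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.DecimalType

// Hypothetical UDF; Spark infers DecimalType(38, 18) for the BigDecimal result.
val doublePrice = udf((bd: java.math.BigDecimal) =>
  if (bd == null) null else bd.multiply(new java.math.BigDecimal(2))
)

// Cast the UDF output back to the desired (18, 2) precision and scale.
val result = df.withColumn(
  "price_usd_doubled",
  doublePrice(col("price_usd")).cast(DecimalType(18, 2))
)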
