首页 > 解决方案 > 本地训练的和 Dataproc 训练的 Spark ML 模型之间的不一致

问题描述

我正在将 Spark 从版本 2.3.1 升级到 2.4.5。我正在使用 Dataproc 映像 1.4.27-debian9 在 Google Cloud Platform 的 Dataproc 上使用 Spark 2.4.5 重新训练模型。当我使用 Spark 2.4.5 在本地机器上加载 Dataproc 生成的模型来验证模型时。不幸的是,我收到以下异常:

20/05/27 08:36:35 INFO HadoopRDD: Input split: file:/Users/.../target/classes/model.ml/stages/1_gbtc_961a6ef213b2/metadata/part-00000:0+657
20/05/27 08:36:35 INFO HadoopRDD: Input split: file:/Users/.../target/classes/model.ml/stages/1_gbtc_961a6ef213b2/metadata/part-00000:0+657
Exception in thread "main" java.lang.IllegalArgumentException: gbtc_961a6ef213b2 parameter impurity given invalid value variance.

加载模型的代码非常简单:

import org.apache.spark.ml.PipelineModel

object ModelLoad {
  def main(args: Array[String]): Unit = {
    val modelInputPath = getClass.getResource("/model.ml").getPath
    val model = PipelineModel.load(modelInputPath)
  }
}

我按照堆栈跟踪检查1_gbtc_961a6ef213b2/metadata/part-00000模型元数据文件,发现以下内容:

{
    "class": "org.apache.spark.ml.classification.GBTClassificationModel",
    "timestamp": 1590593177604,
    "sparkVersion": "2.4.5",
    "uid": "gbtc_961a6ef213b2",
    "paramMap": {
        "maxIter": 50
    },
    "defaultParamMap": {
        ...
        "impurity": "variance",
        ...
    },
    "numFeatures": 1,
    "numTrees": 50
}

杂质设置为,variance但我的本地 spark 2.4.5 期望它是gini. 为了进行完整性检查,我在本地 spark 2.4.5 上重新训练了模型。模型元impurity数据文件设置为gini.

因此,我检查了 GBT Javadoc 中的 spark 2.4.5 setImpurity 方法。它说The impurity setting is ignored for GBT models. Individual trees are built using impurity "Variance."。Dataproc 使用的 spark 2.4.5 似乎与 Apache Spark 文档一致。但是,我从 Maven Central 使用的 Spark 2.4.5 将impurity值设置为gini.

有谁知道为什么 Dataproc 和 Maven Central 中的 Spark 2.4.5 之间存在这种不一致?

我创建了一个简单的训练代码来在本地重现结果:

import java.nio.file.Paths

import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.{DataFrame, SparkSession}

object SimpleModelTraining {
  def main(args: Array[String]) {


    val currentRelativePath = Paths.get("")
    val save_file_location = currentRelativePath.toAbsolutePath.toString

    val spark = SparkSession.builder()
      .config("spark.driver.host", "127.0.0.1")
      .master("local")
      .appName("spark-test")
      .getOrCreate()

    val df: DataFrame = spark.createDataFrame(Seq(
      (0, 0),
      (1, 0),
      (1, 0),
      (0, 1),
      (0, 1),
      (0, 1),
      (0, 2),
      (0, 2),
      (0, 2),
      (0, 3),
      (0, 3),
      (0, 3),
      (1, 4),
      (1, 4),
      (1, 4)
    )).toDF("label", "category")

    val pipeline: Pipeline = new Pipeline().setStages(Array(
      new VectorAssembler().setInputCols(Array("category")).setOutputCol("features"),
      new GBTClassifier().setMaxIter(30)
    ))

    val pipelineModel: PipelineModel = pipeline.fit(df)
    pipelineModel.write.overwrite().save(s"$save_file_location/test_model.ml")
  }
}

谢谢!

标签: scalaapache-sparkgoogle-cloud-dataproc

解决方案


Dataproc 中的 Spark向后移植了针对SPARK-25959的修复程序,该修复程序可能导致本地训练和 Dataproc 训练的 ML 模型之间出现这种不一致。


推荐阅读