scala - 本地训练的和 Dataproc 训练的 Spark ML 模型之间的不一致
问题描述
我正在将 Spark 从版本 2.3.1 升级到 2.4.5。我正在使用 Dataproc 映像 1.4.27-debian9 在 Google Cloud Platform 的 Dataproc 上使用 Spark 2.4.5 重新训练模型。当我使用 Spark 2.4.5 在本地机器上加载 Dataproc 生成的模型来验证模型时。不幸的是,我收到以下异常:
20/05/27 08:36:35 INFO HadoopRDD: Input split: file:/Users/.../target/classes/model.ml/stages/1_gbtc_961a6ef213b2/metadata/part-00000:0+657
20/05/27 08:36:35 INFO HadoopRDD: Input split: file:/Users/.../target/classes/model.ml/stages/1_gbtc_961a6ef213b2/metadata/part-00000:0+657
Exception in thread "main" java.lang.IllegalArgumentException: gbtc_961a6ef213b2 parameter impurity given invalid value variance.
加载模型的代码非常简单:
import org.apache.spark.ml.PipelineModel
object ModelLoad {
def main(args: Array[String]): Unit = {
val modelInputPath = getClass.getResource("/model.ml").getPath
val model = PipelineModel.load(modelInputPath)
}
}
我按照堆栈跟踪检查1_gbtc_961a6ef213b2/metadata/part-00000
模型元数据文件,发现以下内容:
{
"class": "org.apache.spark.ml.classification.GBTClassificationModel",
"timestamp": 1590593177604,
"sparkVersion": "2.4.5",
"uid": "gbtc_961a6ef213b2",
"paramMap": {
"maxIter": 50
},
"defaultParamMap": {
...
"impurity": "variance",
...
},
"numFeatures": 1,
"numTrees": 50
}
杂质设置为,variance
但我的本地 spark 2.4.5 期望它是gini
. 为了进行完整性检查,我在本地 spark 2.4.5 上重新训练了模型。模型元impurity
数据文件设置为gini
.
因此,我检查了 GBT Javadoc 中的 spark 2.4.5 setImpurity 方法。它说The impurity setting is ignored for GBT models. Individual trees are built using impurity "Variance."
。Dataproc 使用的 spark 2.4.5 似乎与 Apache Spark 文档一致。但是,我从 Maven Central 使用的 Spark 2.4.5 将impurity
值设置为gini
.
有谁知道为什么 Dataproc 和 Maven Central 中的 Spark 2.4.5 之间存在这种不一致?
我创建了一个简单的训练代码来在本地重现结果:
import java.nio.file.Paths
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.{DataFrame, SparkSession}
object SimpleModelTraining {
def main(args: Array[String]) {
val currentRelativePath = Paths.get("")
val save_file_location = currentRelativePath.toAbsolutePath.toString
val spark = SparkSession.builder()
.config("spark.driver.host", "127.0.0.1")
.master("local")
.appName("spark-test")
.getOrCreate()
val df: DataFrame = spark.createDataFrame(Seq(
(0, 0),
(1, 0),
(1, 0),
(0, 1),
(0, 1),
(0, 1),
(0, 2),
(0, 2),
(0, 2),
(0, 3),
(0, 3),
(0, 3),
(1, 4),
(1, 4),
(1, 4)
)).toDF("label", "category")
val pipeline: Pipeline = new Pipeline().setStages(Array(
new VectorAssembler().setInputCols(Array("category")).setOutputCol("features"),
new GBTClassifier().setMaxIter(30)
))
val pipelineModel: PipelineModel = pipeline.fit(df)
pipelineModel.write.overwrite().save(s"$save_file_location/test_model.ml")
}
}
谢谢!
解决方案
Dataproc 中的 Spark向后移植了针对SPARK-25959的修复程序,该修复程序可能导致本地训练和 Dataproc 训练的 ML 模型之间出现这种不一致。
推荐阅读
- angular - Angular 显示默认页面。怎么修?
- reactjs - 如何使用 reactjs 将值存储在组件外部的一个变量中
- php - 尝试在 localhost php 中发送电子邮件时出错
- python - CVXPY 是否支持 trace(XT@A@X),其中 X、A 都是矩阵?
- javascript - 如何解决javascript中的异步等待问题?
- javascript - 从 promise 导出变量然后阻塞
- node.js - 谷歌语音到文本无法在 nodejs 上运行
- java - 获取方法中的所有变量名
- python - python从朋友那里接收数据
- c - 在 linux 内核模块中克隆一个文件