scala - Extending DefaultParamsReadable and DefaultParamsWritable does not allow reading a custom model
Problem description
Hello,
For several days now I have been struggling to save a custom transformer that is part of a large pipeline of stages. I have a transformer that is completely defined by its params. I have an estimator whose fit method generates a matrix and then sets the transformer's params accordingly, so that I can use DefaultParamsReadable and DefaultParamsWritable to take advantage of the serialization/deserialization already present in util.ReadWrite.scala.
My summarised code is below (including the important aspects):
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.DenseMatrix
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// trait to implement in Estimator and Transformer for params
trait NBParams extends Params {
  final val featuresCol = new Param[String](this, "featuresCol", "The input column")
  setDefault(featuresCol, "_tfIdfOut")

  final val labelCol = new Param[String](this, "labelCol", "The labels column")
  setDefault(labelCol, "P_Root_Code_Index")

  final val predictionsCol = new Param[String](this, "predictionsCol", "The output column")
  setDefault(predictionsCol, "NBOutput")

  final val ratioMatrix = new Param[DenseMatrix](this, "ratioMatrix", "The transformation matrix")

  def getfeaturesCol: String = $(featuresCol)
  def getlabelCol: String = $(labelCol)
  def getPredictionCol: String = $(predictionsCol)
  def getRatioMatrix: DenseMatrix = $(ratioMatrix)
}

// Estimator
class CustomNaiveBayes(override val uid: String, val alpha: Double)
  extends Estimator[CustomNaiveBayesModel] with NBParams with DefaultParamsWritable {

  def copy(extra: ParamMap): CustomNaiveBayes = defaultCopy(extra)

  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
  def setLabelCol(value: String): this.type = set(labelCol, value)
  def setPredictionCol(value: String): this.type = set(predictionsCol, value)
  def setRatioMatrix(value: DenseMatrix): this.type = set(ratioMatrix, value)

  override def transformSchema(schema: StructType): StructType = {...}

  override def fit(ds: Dataset[_]): CustomNaiveBayesModel = {
    ...
    val model = new CustomNaiveBayesModel(uid)
    model
      .setRatioMatrix(ratioMatrix)
      .setFeaturesCol($(featuresCol))
      .setLabelCol($(labelCol))
      .setPredictionCol($(predictionsCol))
  }
}

// companion object for Estimator
object CustomNaiveBayes extends DefaultParamsReadable[CustomNaiveBayes] {
  override def load(path: String): CustomNaiveBayes = super.load(path)
}

// Transformer
class CustomNaiveBayesModel(override val uid: String)
  extends Model[CustomNaiveBayesModel] with NBParams with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("customnaivebayes"))

  def copy(extra: ParamMap): CustomNaiveBayesModel = defaultCopy(extra)

  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
  def setLabelCol(value: String): this.type = set(labelCol, value)
  def setPredictionCol(value: String): this.type = set(predictionsCol, value)
  def setRatioMatrix(value: DenseMatrix): this.type = set(ratioMatrix, value)

  override def transformSchema(schema: StructType): StructType = {...}

  def transform(dataset: Dataset[_]): DataFrame = {...}
}

// companion object for Transformer
object CustomNaiveBayesModel extends DefaultParamsReadable[CustomNaiveBayesModel]
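For reference, what extending DefaultParamsReadable buys the companion object is small: a read method returning Spark's default metadata-based reader. The sketch below paraphrases util.ReadWrite.scala (the concrete reader class is package-private in Spark, so this is illustrative, not code to compile in user space):

```scala
// Paraphrased from org.apache.spark.ml.util.ReadWrite.scala:
// DefaultParamsReadable only wires read() to the default params reader,
// which reconstructs the instance from the saved JSON param metadata.
trait DefaultParamsReadable[T] extends MLReadable[T] {
  override def read: MLReader[T] = new DefaultParamsReader[T]
}
```

load(path) then comes for free via MLReadable, which simply calls read.load(path).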
When I add this model as part of a pipeline and fit the pipeline, everything runs fine, and when I save the pipeline there are no errors. However, when I try to load the pipeline back in, I get the following error:
NoSuchMethodException: $line3b380bcad77e4e84ae25a6bfb1f3ec0d45.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$ $$$6fa979eb27fa6bf89c6b6d1b271932c$$$$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$CustomNaiveBayesModel.read()
To save the pipeline, which includes a number of other transformers related to NLP preprocessing, I run
fittedModelRootCode.write.save("path")
and then to load it (where the failure occurs) I run
import org.apache.spark.ml.PipelineModel
val fittedModelRootCode = PipelineModel.load("path")
The model itself appears to work well, but I cannot afford to retrain it on the dataset every time I want to use it. Does anyone know why the read() method seems to be unavailable even though I am using a companion object?
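(For context on the error: when loading a pipeline, Spark reads each stage's class name from the saved metadata and reflectively invokes read() on that class, roughly as sketched below; this is simplified from DefaultParamsReader in util.ReadWrite.scala. The REPL-wrapped $iw class name in the exception suggests this reflective lookup is what fails.)

```scala
// Simplified sketch of what Spark's DefaultParamsReader does per stage
// (util.ReadWrite.scala): resolve the class name stored in the stage's
// metadata, reflectively call the companion's read(), then delegate to
// the resulting reader's load(path).
val cls = Class.forName(metadata.className)
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
```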
Notes:
- I am running on Databricks Runtime 8.3 (Spark 3.1.1, Scala 2.12)
- My model is in a separate package, so it is external to Spark
- I have modelled this on a number of existing examples, all of which appear to work fine, so I am unsure why my code is failing
- I am aware there is a Naive Bayes model in Spark ML; however, I have been tasked with heavy customisation, so it was not worth modifying the existing version (also, I would like to learn how to do this properly)
Any help would be greatly appreciated.
Solution
Since your CustomNaiveBayesModel companion object extends DefaultParamsReadable, I think you should use the companion object CustomNaiveBayesModel to load the model. Here is some code I wrote for saving and loading the model, and it works fine:
import org.apache.spark.SparkConf
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.SparkSession
import path.to.CustomNaiveBayesModel

object SavingModelApp extends App {
  val spark: SparkSession = SparkSession.builder().config(
    new SparkConf()
      .setMaster("local[*]")
      .setAppName("Test app")
      .set("spark.driver.host", "localhost")
      .set("spark.ui.enabled", "false")
  ).getOrCreate()

  val training = spark.createDataFrame(Seq(
    (0L, "a b c d e spark", 1.0),
    (1L, "b d", 0.0),
    (2L, "spark f g h", 1.0),
    (3L, "hadoop mapreduce", 0.0)
  )).toDF("id", "text", "label")

  val fittedModelRootCode: PipelineModel =
    new Pipeline().setStages(Array(new CustomNaiveBayesModel())).fit(training)
  fittedModelRootCode.write.save("path/to/model")
  val mod = PipelineModel.load("path/to/model")
}
I think your mistake is that PipelineModel.load is being used to load the concrete model.
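If the intent is to load the saved stage on its own rather than through PipelineModel, the companion object's inherited load can be called directly (a minimal sketch; the path is illustrative):

```scala
import path.to.CustomNaiveBayesModel

// load() is inherited from DefaultParamsReadable via MLReadable, so the
// companion object can read back a stage it saved itself.
val model: CustomNaiveBayesModel = CustomNaiveBayesModel.load("path/to/saved/stage")
```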
My environment:
scalaVersion := "2.12.6"
scalacOptions := Seq(
  "-encoding", "UTF-8", "-target:jvm-1.8", "-deprecation",
  "-feature", "-unchecked", "-language:implicitConversions", "-language:postfixOps")
libraryDependencies += "org.apache.spark" %% "spark-core" % "3.1.1"
libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.1.1"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "3.1.1"