
Problem Description

I recently created a PySpark PipelineModel with some custom transformers to generate features that the native Spark transformers cannot produce. Here is an example of one of my transformers. It takes a string label as input and returns the input's superclass label:

# Imports needed to make the snippet self-contained
from pyspark.ml import Transformer
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import udf


class newLabelMap(
    Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable,
):
    # Explicit Param declarations (these shadow the ones inherited from
    # HasInputCol/HasOutputCol, but keep the class self-describing)
    inputCol = Param(Params._dummy(), "inputCol", "The input column", TypeConverters.toString)
    outputCol = Param(Params._dummy(), "outputCol", "The output column", TypeConverters.toString)

    def __init__(self, inputCol="", outputCol=""):
        super(newLabelMap, self).__init__()
        self._setDefault(inputCol="", outputCol="")
        self._set(inputCol=inputCol, outputCol=outputCol)

    def getInputCol(self):
        return self.getOrDefault(self.inputCol)

    def setInputCol(self, inputCol):
        self._set(inputCol=inputCol)

    def getOutputCol(self):
        return self.getOrDefault(self.outputCol)

    def setOutputCol(self, outputCol):
        self._set(outputCol=outputCol)

    def _transform(self, dataset):
        # Map each incoming label to its superclass label with a string UDF
        @udf("string")
        def findLabel(labelVal):
            new_label_dict = {'oldLabel0' : 'newLabel0',
                          'oldLabel1' : 'newLabel1',
                          'oldLabel2' : 'newLabel1',
                          'oldLabel3' : 'newLabel1',
                          'oldLabel4' : 'newLabel2',
                          'oldLabel5' : 'newLabel2',
                          'oldLabel6' : 'newLabel2',
                          'oldLabel7' : 'newLabel3',
                          'oldLabel8' : 'newLabel3',
                          'oldLabel9' : 'newLabel4',
                          'oldLabel10' : 'newLabel4'}

            try:
                return new_label_dict[labelVal]
            except KeyError:
                # Unknown labels fall into a catch-all bucket
                return 'other'

        out_col = self.getOutputCol()
        in_col = dataset[self.getInputCol()]
        return dataset.withColumn(out_col, findLabel(in_col))
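
For reference, the round trip that works on the Python side looks roughly like this (a minimal sketch: the SparkSession variable spark, the toy data, the column names, and the save path are assumptions, not part of the original pipeline):

from pyspark.ml import Pipeline, PipelineModel

# Toy DataFrame with a string label column (hypothetical data)
df = spark.createDataFrame([("oldLabel1",), ("oldLabel9",)], ["label"])

mapper = newLabelMap(inputCol="label", outputCol="superLabel")
model = Pipeline(stages=[mapper]).fit(df)

# Saving and reloading inside PySpark works because the custom class
# is importable in the Python session
model.write().overwrite().save("/tmp/new_label_pipeline")
reloaded = PipelineModel.load("/tmp/new_label_pipeline")
reloaded.transform(df).show()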

The transformer works fine in the pipeline: I can save it, load it back into a PySpark session, and transform without any problems. The problem arises when I try to import it into a Scala environment. When I try to load the model, I get this error output:

Name: java.lang.IllegalArgumentException
Message: requirement failed: Error loading metadata: Expected class name org.apache.spark.ml.PipelineModel but found class name pyspark.ml.pipeline.PipelineModel
StackTrace:   at scala.Predef$.require(Predef.scala:224)
  at org.apache.spark.ml.util.DefaultParamsReader$.parseMetadata(ReadWrite.scala:638)
  at org.apache.spark.ml.util.DefaultParamsReader$.loadMetadata(ReadWrite.scala:616)
  at org.apache.spark.ml.Pipeline$SharedReadWrite$.load(Pipeline.scala:267)
  at org.apache.spark.ml.PipelineModel$PipelineModelReader.load(Pipeline.scala:348)
  at org.apache.spark.ml.PipelineModel$PipelineModelReader.load(Pipeline.scala:342)

If I remove the custom transformer, it loads fine in Scala, so I'm curious how I can use custom transformers written in PySpark in a PipelineModel that is portable to a Scala environment. Do I need to attach my code in some way? Any help is greatly appreciated :)

Tags: python-3.x, scala, apache-spark, pyspark, pipeline

Solution
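
The error is expected: because the pipeline contains a Python-only stage, PySpark saves it with the Python writer, and the metadata records the class pyspark.ml.pipeline.PipelineModel, which the Scala reader cannot resolve to a JVM class. A transformer written purely in Python with DefaultParamsReadable/DefaultParamsWritable has no JVM counterpart, so there is no way to simply "attach" the Python code for Scala to execute.

The commonly suggested approach works in the reverse direction: implement the transformer in Scala, package it in a JAR that is on the classpath of both the PySpark and the Scala sessions, and expose it to Python through a thin wrapper so that every pipeline stage is JVM-backed. Below is a minimal sketch of the wrapper side; the fully qualified class name com.example.ml.NewLabelMap is a hypothetical placeholder for wherever the Scala implementation actually lives:

from pyspark import keyword_only
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaTransformer

class NewLabelMap(JavaTransformer, HasInputCol, HasOutputCol,
                  JavaMLReadable, JavaMLWritable):
    # Thin wrapper: all real work happens in the (hypothetical)
    # Scala class com.example.ml.NewLabelMap
    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super(NewLabelMap, self).__init__()
        # Instantiate the JVM-side transformer; the name must match the
        # Scala implementation exactly
        self._java_obj = self._new_java_obj("com.example.ml.NewLabelMap", self.uid)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

On the Scala side, com.example.ml.NewLabelMap would extend org.apache.spark.ml.Transformer with DefaultParamsWritable, declare params named inputCol and outputCol to match the wrapper, and provide a companion object extending DefaultParamsReadable. With all stages JVM-backed, a PipelineModel saved from PySpark records JVM class names in its metadata and can then be loaded in Scala with org.apache.spark.ml.PipelineModel.load.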

