首页 > 解决方案 > 通过修改其私有方法来自定义 Spark ML 估计器(例如 GaussianMixture)的正确方法?

问题描述

我的代码使用了 apache.ml.clustering.GaussianMixture,但是它的 init 方法private def initRandom(...)效果不好,所以我想自定义一个新init方法。

起初我想“扩展” class GuassianMixture,但initRandom它是一种私有方法。

然后我尝试了另一种方法,即设置初始 GMM,但遗憾的是源代码显示TODO: SPARK-15785 Support users 提供了初始 GMM

我也尝试class GuassianMixture为我的自定义类复制代码,但是附加的东西太多了。GaussianMixture.scala带有一些类和特征,其中一些只能在 ML 包中访问。

标签: scalaapache-sparkextendsapache-spark-ml

解决方案


我自己解决了。这是我的解决方案。

我创建了从官方包CustomGaussianMixture扩展的类。GaussianMixtureorg.apache.spark.ml.clustering

在我的项目中,我创建了一个新包,也命名为org.apache.spark.ml.clustering(以防止处理复杂类/特征/对象的范围org.apache.spark.ml.clustering)。并将我的自定义类放入其中。

接下来是重写 method( fit) 调用initRandom,一个非私有方法,所以我可以重写它。具体来说,只要在class里写我新的init方法,把官方源码里面的CustomGaussianMixture方法复制到class里面,记得修改里面的代码调用我自定义的init方法。fitGaussianMixture.scalaCustomGaussianMixtureCustomGaussianMixture.fit()

最后,只需在需要时使用CustomGaussianMixture而不是GaussianMixture


推荐阅读