Adding prediction thresholds to the MultilayerPerceptronClassifier class in PySpark

Problem description

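I am trying to optimize the prediction thresholds of MultilayerPerceptronClassifier in (Py)Spark using cross-validation. I tried to create a subclass of MultilayerPerceptronClassifier that actually allows thresholds to be provided. It seems to work fine in a Pipeline, but whenever I plug it into a CrossValidator, it throws an error.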

My class:

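from pyspark.ml import Pipeline
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.param.shared import HasThresholds
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

class MLP(MultilayerPerceptronClassifier, HasThresholds):

    def __init__(self, thresholds=None, **kwargs):
        super(MLP, self).__init__(**kwargs)
        self.setParams(thresholds=thresholds, **kwargs)

    def setParams(self, thresholds=None, **kwargs):
        return self._set(thresholds=thresholds, **kwargs)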

Sample data (labeled):

+---------+-----+------+------------------------------+
|family_id|label|weight|                      embedded|
+---------+-----+------+------------------------------+
| 60009405|  1.0|   1.0|[0.10171283965701926,0.0415...|
| 55022499|  1.0|   1.0|[0.15376672673361091,-0.001...|
| 63938820|  1.0|   1.0|[0.16867649792968614,0.0126...|
| 37452877|  1.0|   1.0|[0.18771651450592225,0.0191...|
| 64559476|  1.0|   1.0|[0.1504634794488278,-0.0032...|
| 59544896|  0.0|  1.25|[0.12911133907668226,0.0116...|
| 46383793|  0.0|  1.25|[0.13390121417649795,-0.013...|
| 59473587|  0.0|  1.25|[0.1262944439844325,0.01176...|
| 63938820|  0.0|  1.25|[0.16867649792968614,0.0126...|
+---------+-----+------+------------------------------+

This seems to work fine:

mlp = MLP(featuresCol='embedded', layers=[200, 10, 2], thresholds=[1e-20, 1-1e-20])
pipe = Pipeline(stages=[mlp])
model = pipe.fit(labeled)
model.transform(labeled).show(10)
+---------+-----+------+--------------------+--------------------+--------------------+----------+
|family_id|label|weight|            embedded|       rawPrediction|         probability|prediction|
+---------+-----+------+--------------------+--------------------+--------------------+----------+
| 60009405|  1.0|   1.0|[0.10171283965701...|[-11.937067045534...|[6.74683311024104...|       0.0|
| 55022499|  1.0|   1.0|[0.15376672673361...|[-11.914377530833...|[7.32793349054270...|       0.0|
| 63938820|  1.0|   1.0|[0.16867649792968...|[-0.5160228904601...|[0.50001341804946...|       0.0|
| 37452877|  1.0|   1.0|[0.18771651450592...|[-10.034360656260...|[4.62078113096099...|       0.0|
| 64559476|  1.0|   1.0|[0.15046347944882...|[-11.971196504198...|[6.19667960173464...|       0.0|
| 59544896|  0.0|  1.25|[0.12911133907668...|[10.5489426088559...|[0.99999999980450...|       0.0|
| 46383793|  0.0|  1.25|[0.13390121417649...|[10.6067487531592...|[0.99999999982723...|       0.0|
| 59473587|  0.0|  1.25|[0.12629444398443...|[10.5199541406352...|[0.99999999979221...|       0.0|
| 63938820|  0.0|  1.25|[0.16867649792968...|[-0.5160228904601...|[0.50001341804946...|       0.0|
+---------+-----+------+--------------------+--------------------+--------------------+----------+

Note that I set the thresholds to extreme values to show that, with these thresholds, the model always predicts 0.
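Spark documents the thresholds rule as: the class with the largest p/t is predicted, where p is the class probability and t its threshold. A pure-Python sketch of that rule (an illustration, not Spark's implementation; the probabilities are made up) shows why [1e-20, 1-1e-20] forces class 0: unless p0 is itself around 1e-20 or smaller, p0 / 1e-20 is far larger than p1 / 1.0.

```python
def predict_with_thresholds(probability, thresholds):
    """Return the class index with the largest p / t.

    A threshold of 0 is treated as an infinite ratio, so that class wins.
    """
    if len(probability) != len(thresholds):
        raise ValueError("need exactly one threshold per class")
    scores = [float("inf") if t == 0 else p / t for p, t in zip(probability, thresholds)]
    # index of the largest score (the first one wins on ties)
    return max(range(len(scores)), key=scores.__getitem__)


extreme = [1e-20, 1 - 1e-20]  # 1 - 1e-20 rounds to 1.0 in floating point
# Even a confident class-1 probability loses: 0.001 / 1e-20 dwarfs 0.999 / 1.0.
print(predict_with_thresholds([0.001, 0.999], extreme))  # -> 0
print(predict_with_thresholds([0.001, 0.999], [0.5, 0.5]))  # -> 1
```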

Now, the following does not work:

mlp = MLP(featuresCol='embedded', layers=[200, 10, 2])
pipe = Pipeline(stages=[mlp])
grid = ParamGridBuilder().\
    addGrid(mlp.thresholds, [[0.3, 0.7], [0.7, 0.3]]).\
    build()
cv = CrossValidator(estimator=pipe,
                    evaluator=MulticlassClassificationEvaluator(metricName='f1'),
                    numFolds=2,
                    estimatorParamMaps=grid,
                    parallelism=len(grid))
model = cv.fit(labeled)

error message:
Traceback (most recent call last):
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3326, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-190-27dfb1e1d326>", line 1, in <module>
    cv.fit(labeled)
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/pyspark/ml/base.py", line 132, in fit
    return self._fit(dataset)
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/pyspark/ml/tuning.py", line 303, in _fit
    tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/pyspark/ml/tuning.py", line 49, in _parallelFitTasks
    modelIter = est.fitMultiple(train, epm)
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/pyspark/ml/base.py", line 103, in fitMultiple
    estimator = self.copy()
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/pyspark/ml/pipeline.py", line 128, in copy
    stages = [stage.copy(extra) for stage in that.getStages()]
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/pyspark/ml/pipeline.py", line 128, in <listcomp>
    stages = [stage.copy(extra) for stage in that.getStages()]
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/pyspark/ml/wrapper.py", line 262, in copy
    that._transfer_params_to_java()
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/pyspark/ml/wrapper.py", line 124, in _transfer_params_to_java
    pair = self._make_java_param_pair(param, self._paramMap[param])
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/pyspark/ml/wrapper.py", line 115, in _make_java_param_pair
    return java_param.w(java_value)
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/py4j/java_gateway.py", line 1257, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/Users/thijsvandepoll/PycharmProjects/focusbv/venv/lib/python3.7/site-packages/py4j/protocol.py", line 328, in get_return_value
    format(target_id, ".", name), value)
py4j.protocol.Py4JJavaError: An error occurred while calling o8256.w.
: java.lang.NullPointerException
    at scala.collection.mutable.ArrayOps$ofDouble$.length$extension(ArrayOps.scala:276)
    at scala.collection.mutable.ArrayOps$ofDouble.length(ArrayOps.scala:276)
    at scala.collection.IndexedSeqOptimized$class.prefixLengthImpl(IndexedSeqOptimized.scala:38)
    at scala.collection.IndexedSeqOptimized$class.forall(IndexedSeqOptimized.scala:43)
    at scala.collection.mutable.ArrayOps$ofDouble.forall(ArrayOps.scala:270)
    at org.apache.spark.ml.param.shared.HasThresholds$$anonfun$2.apply(sharedParams.scala:201)
    at org.apache.spark.ml.param.shared.HasThresholds$$anonfun$2.apply(sharedParams.scala:201)
    at org.apache.spark.ml.param.Param.validate(params.scala:72)
    at org.apache.spark.ml.param.ParamPair.<init>(params.scala:656)
    at org.apache.spark.ml.param.Param.$minus$greater(params.scala:87)
    at org.apache.spark.ml.param.Param.w(params.scala:83)
    at sun.reflect.GeneratedMethodAccessor66.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)

It seems it cannot find the thresholds parameter. I don't know how to fix this.

Can anyone help me?

Tags: machine-learning, pyspark, classification, cross-validation, apache-spark-ml

Solution


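It depends on which Spark version you are using. Native support for thresholds in PySpark's MultilayerPerceptronClassifier was implemented in Spark 3.0.0 and does not exist in Spark versions < 3.0.0, which is why you get the error (from Spark 3.1.1 on, many problems with imbalanced datasets become easy to solve).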
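One reading of the stack trace in the question, assuming Spark 2.4.x: the NullPointerException is thrown by the thresholds validator (HasThresholds in sharedParams.scala) while Param.w converts the Python value, so the value sent to the JVM was null. In the failing example MLP is built without thresholds, so the subclass's setParams calls self._set(thresholds=None, ...); PySpark keeps that None in the param map until copy() transfers it to Java. In the working example thresholds was passed explicitly. A minimal sketch of a fix is to drop unset keyword arguments before calling _set; supplied_params is a hypothetical helper, not a PySpark API.

```python
def supplied_params(**kwargs):
    """Keep only the keyword arguments that were actually given (value not None).

    Hypothetical helper: filtering this way before calling Params._set leaves
    an omitted param such as thresholds unset instead of storing None.
    """
    return {name: value for name, value in kwargs.items() if value is not None}


params = supplied_params(thresholds=None, featuresCol="embedded", layers=[200, 10, 2])
print(params)  # -> {'featuresCol': 'embedded', 'layers': [200, 10, 2]}
```

In the MLP subclass from the question, setParams would then become return self._set(**supplied_params(thresholds=thresholds, **kwargs)). In Spark 2.4.x, MultilayerPerceptronClassifier.__init__ itself calls self.setParams(**kwargs), so the override covers both the constructor and later calls; the grid values still reach the estimator through the extra param map that CrossValidator passes to fit.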

This is what you can access in Spark < 3.0.0:

model_fit.bestModel.thresholds
Param(parent='MultilayerPerceptronClassifier_xxxxxx', name='thresholds', doc="Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold")
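The constraints in the doc string above can be written out as a small pure-Python check. This is a sketch of the documented rules, not Spark's own validator; check_thresholds is a hypothetical name, and it rejects a None list explicitly instead of failing on it.

```python
def check_thresholds(thresholds, num_classes):
    """Validate a thresholds list against the documented constraints.

    One value per class, every value >= 0, at most one value equal to 0.
    Raises ValueError on any violation; returns the values as floats.
    """
    if thresholds is None:
        raise ValueError("thresholds is None: leave the param unset instead")
    if len(thresholds) != num_classes:
        raise ValueError(f"expected {num_classes} thresholds, got {len(thresholds)}")
    if any(t < 0 for t in thresholds):
        raise ValueError("thresholds must be non-negative")
    if sum(1 for t in thresholds if t == 0) > 1:
        raise ValueError("at most one threshold may be 0")
    return [float(t) for t in thresholds]


print(check_thresholds([0.3, 0.7], 2))  # -> [0.3, 0.7]
```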

In Spark >= 3.0.0:

setThresholds(value)
Sets the value of thresholds.

New in version 3.1.0: more parameters were added, and alongside MultilayerPerceptronClassificationModel there is now a MultilayerPerceptronClassificationSummary.

More details can be found in the release notes:

Spark release 3.0.0

MLlib
Highlight

Multiple columns support was added to Binarizer (SPARK-23578), StringIndexer (SPARK-11215), StopWordsRemover (SPARK-29808) and PySpark QuantileDiscretizer (SPARK-22796)
Support Tree-Based Feature Transformation (SPARK-13677)
Two new evaluators MultilabelClassificationEvaluator (SPARK-16692) and RankingEvaluator (SPARK-28045) were added
Sample weights support was added in DecisionTreeClassifier/Regressor (SPARK-19591), RandomForestClassifier/Regressor (SPARK-9478), GBTClassifier/Regressor (SPARK-9612), RegressionEvaluator (SPARK-24102), BinaryClassificationEvaluator (SPARK-24103), BisectingKMeans (SPARK-30351), KMeans (SPARK-29967) and GaussianMixture (SPARK-30102)
R API for PowerIterationClustering was added (SPARK-19827)
Added Spark ML listener for tracking ML pipeline status (SPARK-23674)
Fit with validation set was added to Gradient Boosted Trees in Python (SPARK-24333)
RobustScaler transformer was added (SPARK-28399)
Factorization Machines classifier and regressor were added (SPARK-29224)
Gaussian Naive Bayes (SPARK-16872) and Complement Naive Bayes (SPARK-29942) were added
ML function parity between Scala and Python (SPARK-28958)
predictRaw is made public in all the Classification models. predictProbability is made public in all the Classification models except LinearSVCModel (SPARK-30358)
Changes of behavior

Please read the migration guide for details.

A few other behavior changes that are missed in the migration guide:

In Spark 3.0, a multiclass logistic regression in Pyspark will now (correctly) return LogisticRegressionSummary, not the subclass BinaryLogisticRegressionSummary. The additional methods exposed by BinaryLogisticRegressionSummary would not work in this case anyway. (SPARK-31681)
In Spark 3.0, pyspark.ml.param.shared.Has* mixins do not provide any set<Param>(self, value) setter methods anymore, use the respective self.set(self.<Param>, value) instead. See SPARK-29093 for details. (SPARK-29093)
Spark release 3.1.1

PySpark
Project Zen

Project Zen: Improving Python usability (SPARK-32082)
PySpark type hints support (SPARK-32681)
Redesign PySpark documentation (SPARK-31851)
Migrate to NumPy documentation style (SPARK-32085)
Installation option for PyPI Users (SPARK-32017)
Un-deprecate inferring DataFrame schema from list of dict (SPARK-32686)
Simplify the exception message from Python UDFs (SPARK-33407)
Other Notable Changes

Stage Level Scheduling APIs (SPARK-29641)
Deduplicate deterministic PythonUDF calls (SPARK-33303)
Support higher order functions in PySpark functions (SPARK-30681)
Support data source v2x write APIs (SPARK-29157)
Support percentile_approx in PySpark functions (SPARK-30569)
Support inputFiles in PySpark DataFrame (SPARK-31763)
Support withField in PySpark Column (SPARK-32835)
Support dropFields in PySpark Column (SPARK-32511)
Support nth_value in PySpark functions (SPARK-33020)
Support acosh, asinh and atanh (SPARK-33563)
Support getCheckpointDir method in PySpark SparkContext (SPARK-33017)
Support to fill nulls for missing columns in unionByName (SPARK-32798)
Update cloudpickle to v1.5.0 (SPARK-32094)
Add MapType support for PySpark with Arrow (SPARK-24554)
DataStreamReader.table and DataStreamWriter.toTable (SPARK-33836)
Changes of behavior

Please read the migration guides for PySpark.

Programming guides: PySpark Getting Started and PySpark User Guide.

Structured Streaming
Performance Enhancements

Cache fetched list of files beyond maxFilesPerTrigger as unread file (SPARK-30866)
Streamline the logic on file stream source and sink metadata log (SPARK-30462)
Avoid reading compact metadata log twice if the query restarts from compact batch (SPARK-30900)
Feature Enhancements

Add DataStreamReader.table API (SPARK-32885)
Add DataStreamWriter.toTable API (SPARK-32896)
Left semi stream-stream join (SPARK-32862)
Full outer stream-stream join (SPARK-32863)
Provide a new option to have retention on output files (SPARK-27188)
Add Spark Structured Streaming History Server Support (SPARK-31953)
Introduce State schema validation among query restart (SPARK-27237)
Other Notable Changes

Introduce schema validation for streaming state store (SPARK-31894)
Support to use a different compression codec in state store (SPARK-33263)
Kafka connector infinite wait because metadata never updated (SPARK-28367)
Upgrade Kafka to 2.6.0 (SPARK-32568)
Pagination support for Structured Streaming UI pages (SPARK-31642, SPARK-30119)
State information in Structured Streaming UI (SPARK-33223)
Watermark gap information in Structured Streaming UI (SPARK-33224)
Expose state custom metrics information on SS UI (SPARK-33287)
Add a new metric regarding number of rows later than watermark (SPARK-24634)
Changes of behavior

Please read the migration guides for Structured Streaming.

Programming guides: Structured Streaming Programming Guide.

MLlib
Highlight

LinearSVC blockify input vectors (SPARK-30642)
LogisticRegression blockify input vectors (SPARK-30659)
LinearRegression blockify input vectors (SPARK-30660)
AFT blockify input vectors (SPARK-31656)
Add support for association rules in ML (SPARK-19939)
Add training summary for LinearSVCModel (SPARK-20249)
Add summary to RandomForestClassificationModel (SPARK-23631)
Add training summary to FMClassificationModel (SPARK-32140)
Add summary to MultilayerPerceptronClassificationModel (SPARK-32449)
Add FMClassifier to SparkR (SPARK-30820)
Add SparkR LinearRegression wrapper (SPARK-30818)
Add FMRegressor wrapper to SparkR (SPARK-30819)
Add SparkR wrapper for vector_to_array (SPARK-33040)
adaptively blockify instances - LinearSVC (SPARK-32907)
make CrossValidator/TrainValidateSplit/OneVsRest Reader/Writer support Python backend estimator/evaluator (SPARK-33520)
Improve performance of ML ALS recommendForAll by GEMV (SPARK-33518)
Add UnivariateFeatureSelector (SPARK-34080)
Other Notable Changes

GMM compute summary and update distributions in one job (SPARK-31032)
Remove ChiSqSelector dependency on mllib.ChiSqSelectorModel (SPARK-31077)
Flatten the result dataframe of tests in testChiSquare (SPARK-31301)
MinHash keyDistance optimization (SPARK-31436)
KMeans optimization based on triangle-inequality (SPARK-31007)
Add weight support in ClusteringEvaluator (SPARK-31734)
Add getMetrics in Evaluators (SPARK-31768)
Add instance weight support in LinearRegressionSummary (SPARK-31944)
Add user-specified fold column to CrossValidator (SPARK-31777)
ML params default value parity in feature and tuning (SPARK-32310)
Fix double caching in KMeans/BiKMeans (SPARK-32676)
aft transform optimization (SPARK-33111)
FeatureHasher transform optimization (SPARK-32974)
Add array_to_vector function for dataframe column (SPARK-33556)
ML params default value parity in classification, regression, clustering and fpm (SPARK-32310)
Summary.totalIterations greater than maxIters (SPARK-31925)
tree models prediction optimization (SPARK-32298)
Changes of behavior

Please read the migration guides for MLlib.

Programming guide: Machine Learning Library (MLlib) Guide.

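Note: I also ran into this issue on Spark 2.4.5 while modeling, when I was trying to tune the thresholds to improve an MLPC model's performance on a very imbalanced target.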
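For an imbalanced target, the effect of a thresholds grid like the one in the question can be previewed outside Spark. The sketch below is pure Python with made-up data and hypothetical helper names; it scores each candidate thresholds list on a fixed set of held-out probabilities with a weighted F1 (each true class's F1 weighted by its share of the labels), in the spirit of MulticlassClassificationEvaluator(metricName='f1'). Unlike CrossValidator it does not refit anything per candidate: thresholds are only applied when probabilities are turned into predictions, so they do not change the fitted network.

```python
from collections import Counter


def predict(probability, thresholds):
    """Return the class index with the largest p / t (t == 0 counts as infinite)."""
    scores = [float("inf") if t == 0 else p / t for p, t in zip(probability, thresholds)]
    return max(range(len(scores)), key=scores.__getitem__)


def weighted_f1(labels, preds):
    """F1 per true class, weighted by that class's share of the labels."""
    total = len(labels)
    score = 0.0
    for cls, count in Counter(labels).items():
        tp = sum(1 for y, p in zip(labels, preds) if y == cls and p == cls)
        predicted = sum(1 for p in preds if p == cls)
        precision = tp / predicted if predicted else 0.0
        recall = tp / count
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        score += count / total * f1
    return score


def best_thresholds(labels, probabilities, candidates):
    """Score every candidate thresholds list and return (best_f1, thresholds)."""
    scored = []
    for thresholds in candidates:
        preds = [predict(p, thresholds) for p in probabilities]
        scored.append((weighted_f1(labels, preds), thresholds))
    # max keeps the first candidate on ties
    return max(scored, key=lambda pair: pair[0])


# Toy held-out data (made up): class 1 is never predicted at [0.3, 0.7].
labels = [1, 1, 0, 0]
probabilities = [[0.6, 0.4], [0.55, 0.45], [0.9, 0.1], [0.8, 0.2]]
print(best_thresholds(labels, probabilities, [[0.3, 0.7], [0.7, 0.3]]))
# -> (1.0, [0.7, 0.3])
```

Lowering the threshold of the minority class raises its p/t ratio, which is why [0.7, 0.3] recovers both class-1 rows in this toy data.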

