首页 > 解决方案 > 在 PySpark 管道中使用交叉验证进行过采样

问题描述

我正在研究 PySpark 二进制分类管道,我想在其中使用过采样阶段执行交叉验证(我的数据集不平衡)。问题是过采样阶段也在测试数据集上执行。

管道:

pipeline=Pipeline(stages=[cast_and_fill_na, smote, vec_assembler, rf])

smote是转换测试数据集时我想跳过的阶段。

我查看了 spark 文档和源代码,没有办法跳过 PipelineModel 中的一个阶段。我的解决方案是覆盖_transform原始类的方法以跳过过采样阶段。在我的源代码中安装管道时,这可以正常工作。我用这个:

pipeline_model.__class__ = CustomPipelineModel

CustomPipelineModel是一个继承pyspark.ml.PipelineModel并覆盖该_transform方法的类。

但是由于 CrossValidator 使用 PipelineModel 类的原始实现,我不能使用我的自定义方法。

evaluator = BinaryClassificationEvaluator(labelCol=target)    
crossval = CrossValidator(estimator=pipeline,
                                      estimatorParamMaps=paramGrid,
                                      evaluator=evaluator,
                                      numFolds=10,
                                      parallelism=1)
cvModel = crossval.fit(train_set)

使用 Cross Validator 时跳过过采样阶段的最佳方法是什么?

我也开始研究考虑覆盖它的_fit方法的源代码pyspark.ml.tuning.CrossValidator......第二种解决方案是对训练数据集执行过采样,但这会在交叉验证过程中将偏差引入模型。

标签: pythonpysparkcross-validationoversamplingsmote

解决方案


我想出了一个解决这个问题的方法。在我的 SMOTEOversmapler 类(smote 阶段是它的一个实例)中,我添加了一个 atteribute namede skip_transform,它在实例化 SMOTEOversmapler 对象时设置为 None。在该_transform方法中,我将此属性设置为 True。将跳过下一次调用_transform(处于测试阶段)。这是一个代码片段。

def __init__(self, ...):
    self.skip_transfrom = None
def _transform(self, df):
    if self.skip_transform:
         retrun df
    else:
         #Execute oversampling
         self.skip_transform = True

推荐阅读