首页 > 解决方案 > Spark MlLib (Java) 中的分类和数值特征

问题描述

我正在使用 Java 中的 Apache Spark MlLib 2.11 版。我需要将分类和数字特征(字符串和数字)都传递给 RandomForestClassifier。

对于这种情况,最好的 API 是什么?一个例子会很有帮助。

编辑

我尝试使用 VectorIndexer,但它只接受数字,我不明白如何将 OneHotEncoder 集成到它。另外,我不清楚如何区分哪些特征是分类的,哪些是数字的。我需要在哪里设置所有可能的类别?

这是我尝试的一些代码:

StructType schema = DataTypes.createStructType(new StructField[] {
        new StructField("label", DataTypes.StringType, false, Metadata.empty()),
        new StructField("features", new ArrayType(DataTypes.StringType, false), false,
                Metadata.empty()),
});

JavaRDD<Row> rowRDD = trainingData.map(record -> {
    List<String> values = new ArrayList<>();
    for (String field : fields) {
        values.add(record.get(field));
    }
    return RowFactory.create(record.get(Constants.GROUND_TRUTH), values.toArray(new String[0]));
});

Dataset<Row> trainingDataDataframe = spark.createDataFrame(rowRDD, schema);

StringIndexerModel labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("indexedLabel")
        .fit(trainingDataDataframe);

OneHotEncoder encoder = new OneHotEncoder()
        .setInputCol("features")
        .setOutputCol("featuresVec");
Dataset<Row> encoded = encoder.transform(trainingDataDataframe);

VectorIndexerModel featureIndexer = new VectorIndexer()
        .setInputCol("featuresVec")
        .setOutputCol("indexedFeatures")
        .setMaxCategories(maxCategories)
        .fit(encoded);

StringIndexerModel featureIndexer = new StringIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .fit(encoded);

RandomForestClassifier rf = new RandomForestClassifier();
        .setNumTrees(numTrees);
        .setFeatureSubsetStrategy(featureSubsetStrategy);
        .setImpurity(impurity);
        .setMaxDepth(maxDepth);
        .setMaxBins(maxBins);
        .setSeed(seed)
        .setLabelCol("indexedLabel")
        .setFeaturesCol("indexedFeatures");

IndexToString labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels());

Pipeline pipeline = new Pipeline()
        .setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter});

PipelineModel model = pipeline.fit(encoded);

标签: apache-sparkapache-spark-mllib

解决方案


与决策树一样,随机森林不需要 One Hot 编码来管理分类特征,它是少数可以原生管理分类特征的技术之一(也就是说,无需转换为二进制特征,这就是热编码)。

同时处理连续和分类特征的最简单方法是 maxCategories正确设置参数。当您训练您的森林时,将计算每个特征的不同值,并且训练数据中具有小于maxCategories不同值的列将被视为分类。

您可以通过打印树/森林来检查该特征是否是分类的,使用 toDebugString. 如果它是分类的,你会看到类似的东西 if feature0 in {0,1,2}而不是通常的 <=.


推荐阅读