首页 > 解决方案 > 如何找到决策树分类的特征名称?

问题描述

我正在尝试查找决策树的特征信息。更具体地说,如果 183 出现在我的树形可视化中,我希望能够分辨出它是什么特征。我试过 dtModel.getInputCol() 但收到以下错误。

AttributeError: 'DecisionTreeClassificationModel' object has no attribute 'getInputCol'

这是我当前的代码:

from pyspark.ml.classification import DecisionTreeClassifier

# Create initial Decision Tree Model
dt = DecisionTreeClassifier(labelCol="label", featuresCol="features", maxDepth=3)

# Train model with Training Data
dtModel = dt.fit(trainingData)
display(dtModel)

如果您可以提供帮助或需要更多信息,请告诉我。谢谢你。

标签: pythonapache-sparkpysparkdatabricksapache-spark-mllib

解决方案


请参阅此示例取自Spark 文档(我尝试使名称与您的代码一致,尤其是featuresCol="features")。

我假设您有一些这样的代码(在您在问题中发布的代码之前):

featureIndexer = VectorIndexer(inputCol="inputFeatures", outputCol="features", maxCategories=4).fit(data)

在此步骤之后,您拥有"features"索引功能,然后您将其提供给DecisionTreeClassifier(就像您发布的代码一样):

# Train a DecisionTree model.
dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="features")

你要找的是inputFeatures上面的,这是被索引之前的原始特征。如果要打印它,只需执行以下操作:

sc.parallelize(inputFeatures, 1).saveAsTextFile("absolute_path") 

推荐阅读