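pyspark - How to plot a ROC curve for a GBTClassifier in pyspark?
Problem description
I am trying to plot the ROC curve for a gradient boosting model. I came across this post, but it does not seem to work for a GBTClassifier model: "pyspark extract ROC curve?"
I am using a dataset in Databricks, and my code is below. It gives the following error: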
AttributeError: 'PipelineModel' object has no attribute 'summary'
%fs ls databricks-datasets/adult/adult.data
from pyspark.sql.functions import *
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml.feature import StringIndexer, OneHotEncoderEstimator, VectorAssembler, VectorSlicer
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import BinaryClassificationEvaluator,MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
import pandas as pd
dataset = spark.table("adult")
# splitting into train and test data frames
splits = dataset.randomSplit([0.7, 0.3])
train_df = splits[0]
test_df = splits[1]
def predictions(train_df, target_col):
    """
    Fit a GBT pipeline and return the fitted PipelineModel.
    train_df - training DataFrame
    target_col - name of the target variable column
    """
    # string columns get indexed and one-hot encoded; numeric columns are used as-is
    encoding_var = [i[0] for i in train_df.dtypes if i[1] == 'string' and i[0] != target_col]
    num_var = [i[0] for i in train_df.dtypes if i[1] in ('int', 'double') and i[0] != target_col]
    string_indexes = [StringIndexer(inputCol = c, outputCol = 'IDX_' + c, handleInvalid = 'keep') for c in encoding_var]
    # OneHotEncoderEstimator is the Spark 2.3/2.4 name; it was renamed OneHotEncoder in Spark 3.0
    onehot_indexes = [OneHotEncoderEstimator(inputCols = ['IDX_' + c], outputCols = ['OHE_' + c]) for c in encoding_var]
    label_indexes = StringIndexer(inputCol = target_col, outputCol = 'label', handleInvalid = 'keep')
    assembler = VectorAssembler(inputCols = num_var + ['OHE_' + c for c in encoding_var], outputCol = "features")
    gbt = GBTClassifier(featuresCol = 'features', labelCol = 'label',
                        maxDepth = 5,
                        maxBins = 45,
                        maxIter = 20)
    pipe = Pipeline(stages = string_indexes + onehot_indexes + [assembler, label_indexes, gbt])
    model = pipe.fit(train_df)
    return model

gbt_model = predictions(train_df = train_df,
                        target_col = 'income')
import matplotlib.pyplot as plt
plt.figure(figsize=(5,5))
plt.plot([0, 1], [0, 1], 'r--')
# the next call raises the AttributeError: gbt_model is a PipelineModel
plt.plot(gbt_model.summary.roc.select('FPR').collect(),
         gbt_model.summary.roc.select('TPR').collect())
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.show()
Solution
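Based on your error, take a look at PipelineModel in this documentation:
https://spark.apache.org/docs/2.4.3/api/python/pyspark.ml.html#pyspark.ml.PipelineModel
Objects of this class have no summary attribute. Instead, I believe you need to access the stages of the PipelineModel individually, for example gbt_model.stages[-1] (which should give you your last stage, the GBTClassifier, or strictly speaking the fitted GBTClassificationModel). Then try the attributes there, for example:
gbt_model.stages[-1].summary
If your GBTClassifier has a summary, you will find it there. Hope this helps.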
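If the last stage turns out not to expose a summary (as far as I know, GBTClassificationModel does not in Spark 2.4), the ROC points can be computed from the model's predictions instead. The following is only a minimal sketch under some assumptions: the helper names collect_scores and roc_points are hypothetical, the transformed DataFrame is assumed to have a 'probability' vector column and a 'label' column with the positive class indexed as 1, and the scored data is assumed to be small enough to collect to the driver.

```python
def collect_scores(pipeline_model, df, prob_col="probability", label_col="label"):
    """Score df with the fitted pipeline and return (P(label == 1), label) pairs.

    Hypothetical helper. Assumes the transformed DataFrame has a probability
    vector column and a 0/1 label column, and that it fits on the driver.
    """
    rows = pipeline_model.transform(df).select(prob_col, label_col).collect()
    return [(float(r[prob_col][1]), int(r[label_col])) for r in rows]


def roc_points(scores_and_labels):
    """Return (fpr, tpr) lists by sweeping the decision threshold from high to low."""
    pos = sum(1 for _, y in scores_and_labels if y == 1)
    neg = len(scores_and_labels) - pos
    if pos == 0 or neg == 0:
        raise ValueError("need at least one positive and one negative example")

    ordered = sorted(scores_and_labels, key=lambda p: p[0], reverse=True)
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    for i, (score, y) in enumerate(ordered):
        if y == 1:
            tp += 1
        else:
            fp += 1
        # tied scores share one threshold, so emit a point only after the last of them
        if i + 1 == len(ordered) or ordered[i + 1][0] != score:
            fpr.append(fp / neg)
            tpr.append(tp / pos)
    return fpr, tpr


# Intended use with the objects from the question (not executed here):
# pairs = collect_scores(gbt_model, test_df)
# fpr, tpr = roc_points(pairs)
# plt.plot([0, 1], [0, 1], 'r--')
# plt.plot(fpr, tpr)
```

If only the area under the curve is needed, the BinaryClassificationEvaluator already imported in the question can compute areaUnderROC on the transformed DataFrame without collecting anything.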