python - 为 DecisionTreeClassifier 绘制多类 ROC 曲线
问题描述
我试图用文档中提供的 svm.SVC 以外的分类器绘制 ROC 曲线。我的代码适用于 svm.SVC;但是,在我切换到 KNeighborsClassifier、MultinomialNB 和 DecisionTreeClassifier 后,系统一直在 check_consistent_length(y_true, y_score) 处报错:
Found input variables with inconsistent numbers of samples: [26632, 53264]
我的 CSV 文件看起来像这样
这是我的代码
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
import sys
from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from scipy import interp
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.tree import DecisionTreeClassifier
# Purpose: fit a one-vs-rest decision tree on the CSV data and plot a ROC curve.
# Load the CSV, then pick the feature columns (X) and the label column (y).
df = pd.read_csv("E:\\autodesk\\Hourly and weather categorized2.csv")
X =df[['TTI','Max TemperatureF','Mean TemperatureF','Min TemperatureF',' Min Humidity']].values
# NOTE(review): Series.as_matrix() was deprecated and later removed from
# pandas; `.values` (already used for X above) is the usual replacement.
y = df['TTI_Category'].as_matrix()
# Turn the labels into a single column of shape (n_samples, 1).
y=y.reshape(-1,1)
# Binarize the labels against the two classes 'Good' and 'Bad'.
y = label_binarize(y, classes=['Good','Bad'])
# The number of classes is taken from the column count of the binarized labels.
# NOTE(review): with only two classes label_binarize presumably returns ONE
# column, which would make n_classes == 1 here -- confirm against the docs.
n_classes = y.shape[1]
# Shuffle and split: half of the rows go to the test set, with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5,
random_state=0)
# Wrap the decision tree in a one-vs-rest classifier, fit it on the training
# half, and get class-probability scores for the test half.
classifier = OneVsRestClassifier(DecisionTreeClassifier(random_state=0))
y_score = classifier.fit(X_train, y_train).predict_proba(X_test)
# Compute the ROC curve and the area under it for each label column,
# stored in dicts keyed by the column index.
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
# Compute the micro-average ROC curve and area from the flattened arrays.
# NOTE(review): the reported ValueError is raised by this roc_curve call with
# lengths [26632, 53264], i.e. y_score.ravel() is twice as long as
# y_test.ravel(). Presumably y_score has two columns while y_test has one --
# confirm by printing y_test.shape and y_score.shape.
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
# Plot only the curve stored under key 0, plus the dashed diagonal from
# (0, 0) to (1, 1) as a reference line.
plt.figure()
lw = 1
plt.plot(fpr[0], tpr[0], color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc[0])
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
我怀疑错误发生在下面这两行:
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
但我是 ROC 曲线的初学者,希望有人能指导我看懂这个回溯(traceback)。非常感谢您的时间和帮助。这是我提出的另一个关于 ROC 曲线的问题。
顺便说一下,下面是完整的回溯。希望我的解释足够清楚。
Traceback (most recent call last):
File "<ipython-input-1-16eb0db9d4d9>", line 1, in <module>
runfile('C:/Users/Think/Desktop/Python Practice/ROC with decision tree.py', wdir='C:/Users/Think/Desktop/Python Practice')
File "C:\Users\Think\Anaconda2\lib\site-packages\spyder\utils\site\sitecustomize.py", line 880, in runfile
execfile(filename, namespace)
File "C:\Users\Think\Anaconda2\lib\site-packages\spyder\utils\site\sitecustomize.py", line 87, in execfile
exec(compile(scripttext, filename, 'exec'), glob, loc)
File "C:/Users/Think/Desktop/Python Practice/ROC with decision tree.py", line 47, in <module>
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
File "C:\Users\Think\Anaconda2\lib\site-packages\sklearn\metrics\ranking.py", line 510, in roc_curve
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight)
File "C:\Users\Think\Anaconda2\lib\site-packages\sklearn\metrics\ranking.py", line 302, in _binary_clf_curve
check_consistent_length(y_true, y_score)
File "C:\Users\Think\Anaconda2\lib\site-packages\sklearn\utils\validation.py", line 173, in check_consistent_length
" samples: %r" % [int(l) for l in lengths])
ValueError: Found input variables with inconsistent numbers of samples: [26632, 53264]
解决方案
您需要使用 DecisionTreeClassifier 的 predict_proba 方法:
例子:
# Example: one-vs-rest ROC curves for a DecisionTreeClassifier on the iris data.
from itertools import cycle  # needed for the colour iterator used when plotting

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc
from sklearn.multiclass import OneVsRestClassifier
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
X = iris.data
y = iris.target

# Binarize the output: one indicator column per class (0, 1, 2), so the
# column count of y gives the number of classes.
y = label_binarize(y, classes=[0, 1, 2])
n_classes = y.shape[1]

# Hold out half of the samples for testing, with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=.5, random_state=0)

# Fit one decision tree per class and score the test set with predict_proba,
# which gives one score column per class to pair with the columns of y_test.
classifier = OneVsRestClassifier(DecisionTreeClassifier(random_state=0))
y_score = classifier.fit(X_train, y_train).predict_proba(X_test)

# Per-class ROC curve and area under the curve, keyed by class index.
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Line width shared by every plotted line; it must be defined before the
# plt.plot calls below, which all reference it.
lw = 2

# Draw one curve per class; cycle() repeats the colours if there are more
# classes than colours listed.
colors = cycle(['blue', 'red', 'green'])
for i, color in zip(range(n_classes), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'
             ''.format(i, roc_auc[i]))

# Dashed diagonal from (0, 0) to (1, 1) as a reference line.
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([-0.05, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic for multi-class data')
plt.legend(loc="lower right")
plt.show()
推荐阅读
- makefile - 我可以在生成文件中添加动态地址吗?
- arangodb - 如何替换 arangodb 中的列表元素
- java - 在命令中使用带有空格的 Runtime.exec 时“无法运行程序”
- javascript - 无法在 MacOS 上打开 Intellij
- react-native - 付款完成后获取信用卡号
- c# - 如何在运行时更改 url 徽标
- regex - 正则表达式无法对两个匹配项使用相同的字符
- html - CSS 元素出现在当前行下方。怎么修?
- xcode - 如何消除有关 MobileCoreServices 和 AssetsLibrary 的 Xcode 11.4 警告?
- php - 双击单个链接