neural-network - 使用 keras 绘制 Roc 曲线
问题描述
我有一个神经网络模型,我正在使用 KerasClassifier,然后使用 KFold 进行交叉验证。现在我在绘制 ROC 曲线时遇到了问题。我尝试了很少的代码,但大多数代码都给了我一个多标签未解释的错误。在我的神经网络产生准确性之前,我有以下代码。如果有人可以帮助我完成代码的后半部分,我将不胜感激。
import numpy as np
import pandas as pd
from keras.layers import Dense, Input
from keras.models import Model, Sequential
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils import np_utils
from sklearn.model_selection import cross_val_score, KFold
from sklearn.preprocessing import LabelEncoder, MinMaxScaler,StandardScaler
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
seed = 7
np.random.seed(seed)
dataset = pd.read_csv('lukemia_2003.csv')
X_train = dataset.values[:,0:12600]
Y_train = dataset.values[:,12600]
scalar = MinMaxScaler()
scaled_data = scalar.fit_transform(X_train)
pca = PCA(n_components=10)
X_train_pca = pca.fit_transform(scaled_data)
encoder = LabelEncoder()
encoder.fit(Y_train)
encoded_Y = encoder.transform(Y_train)
dummy_Y = np_utils.to_categorical(encoded_Y)
hid_layer1 = 4
hid_layer2 = 4
output_layer = 4
def my_model():
encoded = Sequential()
encoded.add(Dense(hid_layer1, input_dim = 10, activation='tanh'))
encoded.add(Dense(hid_layer2, activation='tanh'))
encoded.add(Dense(output_layer, activation='softmax'))
encoded.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return encoded
result_mean_list = []
std_list = []
for i in range(30):
estimator = KerasClassifier(build_fn=my_model, epochs=1500, batch_size=5, verbose=2)
kfold = KFold(n_splits=10, shuffle=True, random_state=seed)
results = cross_val_score(estimator, X_train_pca, dummy_Y, cv=kfold)
result_mean_list.append(round(results.mean()*100,2))
std_list.append(round(results.std()*100,2))
print ("Result mean: ", result_mean_list)
print ("Standard Deviation List: ", std_list)
这是数据集的链接。https://drive.google.com/open?id=15emI90-sPZMkHLuwRbNfTBli0h_S-PpM
解决方案
对于您的情况,由于您的目标是多类的,因此您不能使用 ROC 来评估分类器。在存在二元分类器的情况下,此链接显示如何绘制 ROC 曲线。
推荐阅读
- python-3.x - 如何使用 python 的 gspread wks.update_cells 更新谷歌电子表格中的单元格
- php - 使用 slim 框架在 php 中显示唯一的数组数据
- python - 带有保持额外价值的条件的python格式选项
- react-native - 如何在本机反应中在浮动按钮上获得阴影
- operating-system - 列出在完全专用的机器上运行程序所必需的四个步骤——一台只运行该程序的计算机
- javascript - JS ES6 IIFE + 符号和原型 - 添加到实例?
- c# - 为什么 MySQL 返回在 C# 代码中始终为 1,但在存储过程中测试时却不是?
- python - 即使使用 __init__.py 文件,导入也不起作用
- python - 如何使用 opencv 操作将 `plt.imsave` 替换为 `cmap` 选项设置为 `gray`?
- java - 如何检查在 BottomNavigationView 中膨胀了哪个菜单文件?