python - keras: Evaluating ROC AUC for a multi-class CNN
Question
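I am building a CNN for a 5-class problem with the keras Sequential() API. Since accuracy is not a good metric for multi-class problems, I need other metrics to evaluate my model. Currently I use sklearn's confusion_matrix and classification_report, but I want to look at more metrics, so I decided to evaluate ROC AUC. I am not sure how this is done in keras or what changes I should make to my code.
Currently, this is how I build the model: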
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

model = Sequential()
activ = 'relu'
model.add(Conv2D(32, (1, 3), strides=(1, 1), padding='same', activation=activ, input_shape=(1, 100, 4)))
model.add(Conv2D(32, (1, 3), strides=(1, 1), padding='same', activation=activ ))
model.add(MaxPooling2D(pool_size=(1, 2) ))
model.add(Conv2D(64, (1, 3), strides=(1, 1), padding='same', activation=activ))
model.add(Conv2D(64, (1, 3), strides=(1, 1), padding='same', activation=activ))
model.add(MaxPooling2D(pool_size=(1, 2)))
model.add(Flatten())
A = model.output_shape
model.add(Dense(int(A[1] * 1/4.), activation=activ))
model.add(Dense(5, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=64, shuffle=True, callbacks=callbacks,
          validation_split=0.2)
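Evaluation on the test set:
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

Pred = model.predict(x_test, batch_size=32)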
Pred_Label = np.argmax(Pred, axis=1)
test_acc = accuracy_score(y_test, Pred_Label)
ConfusionM = confusion_matrix(list(y_test), Pred_Label, labels=[0, 1, 2, 3, 4])
class_report = classification_report(list(y_test), Pred_Label, labels=[0, 1, 2, 3, 4])
The results are as follows:
Confusion Matrix:
[[ 2514  1040  2584  6690  1773]
 [  208   359    37   668   126]
 [ 1445  1156  1172  3438  1106]
 [ 3158  2014  2993 10185  1951]
 [  154    77    29   493   151]]
Classification Report:
              precision    recall  f1-score   support

     Class:0       0.34      0.17      0.23     14601
     Class:1       0.08      0.26      0.12      1398
     Class:2       0.17      0.14      0.15      8317
     Class:3       0.47      0.50      0.49     20301
     Class:4       0.03      0.17      0.05       904

    accuracy                           0.32     45521
   macro avg       0.22      0.25      0.21     45521
weighted avg       0.35      0.32      0.32     45521
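How can I add ROC AUC to my model's metrics?
Answer
Another way to plot ROC curves for a multi-class classifier is shown below. Let's take a toy problem, CIFAR10, a multi-class dataset with 10 different classes.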
import tensorflow as tf
import numpy as np
(x_train, y_train), (_, _) = tf.keras.datasets.cifar10.load_data()
# train set / data
x_train = x_train.astype('float32') / 255
y_train = tf.keras.utils.to_categorical(y_train , num_classes=10)
print(x_train.shape, y_train.shape)
# (50000, 32, 32, 3) (50000, 10)
The model
inputs = tf.keras.Input(shape=(32, 32, 3))
# A single convolutional layer as a minimal feature extractor.
conv = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3),
                              strides=(1, 1), activation='relu')(inputs)
# Apply global max pooling.
gmp = tf.keras.layers.GlobalMaxPooling2D()(conv)
# Finally, add a classification layer.
output = tf.keras.layers.Dense(10, activation='softmax')(gmp)
# Bind inputs and outputs into a model.
func_model = tf.keras.Model(inputs, output)
Compile and run
func_model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
    optimizer=tf.keras.optimizers.Adam())
# fit
func_model.fit(x_train, y_train, batch_size=128, epochs=20, verbose=1)
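Since the question asks how to add ROC AUC to the compiled metrics, here is a minimal sketch, assuming a TensorFlow version whose tf.keras.metrics.AUC accepts the multi_label and num_labels arguments. With multi_label=True the metric scores each class column one-vs-rest and averages the per-class AUCs. The toy arrays `y_true` and `y_prob` are illustrative values, not output from either model above.

```python
import numpy as np
import tensorflow as tf

# ROC AUC as a Keras metric; multi_label=True evaluates each class column
# separately (one-vs-rest) and averages the per-class AUCs.
roc_auc = tf.keras.metrics.AUC(curve="ROC", multi_label=True,
                               num_labels=5, name="roc_auc")

# Toy one-hot labels and softmax-like scores (illustrative values only).
labels = np.array([0, 1, 2, 3, 4, 0])
y_true = np.eye(5, dtype="float32")[labels]
y_prob = np.full((6, 5), 0.025, dtype="float32")
y_prob[np.arange(6), labels] = 0.9  # each row still sums to 1

roc_auc.update_state(y_true, y_prob)
value = float(roc_auc.result())
print("ROC AUC:", value)

# The same kind of metric object can be passed to compile(), for example:
# model.compile(optimizer="adam", loss="categorical_crossentropy",
#               metrics=["accuracy",
#                        tf.keras.metrics.AUC(multi_label=True, num_labels=5)])
```

Keras approximates the curve with a fixed set of thresholds, so the reported value may differ slightly from an exact computation such as scikit-learn's.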
Get the predicted labels and the true labels
ypred = func_model.predict(x_train)
ypred = ypred.argmax(axis=-1)
ypred
# array([7, 9, 7, ..., 9, 1, 1])
ytrain = y_train.argmax(axis=-1)
ytrain
# array([6, 9, 9, ..., 9, 1, 1])
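One caveat worth noting: `ypred` here holds hard class labels rather than probabilities. A per-class ROC curve built from hard 0/1 predictions has only one operating point, so its AUC reduces to (TPR + 1 - FPR) / 2. The sketch below checks that identity with scikit-learn on made-up labels; `y_true` and `y_hat` are illustrative arrays, not taken from the CIFAR10 run.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Made-up binary ground truth and hard 0/1 predictions for a single class
# (the one-vs-rest view of one column after label binarization).
y_true = np.array([1, 1, 1, 0, 0, 0, 0, 1, 0, 0])
y_hat = np.array([1, 0, 1, 0, 1, 0, 0, 1, 0, 0])

# Confusion counts for that class.
tp = np.sum((y_true == 1) & (y_hat == 1))
fn = np.sum((y_true == 1) & (y_hat == 0))
fp = np.sum((y_true == 0) & (y_hat == 1))
tn = np.sum((y_true == 0) & (y_hat == 0))
tpr = tp / (tp + fn)
fpr = fp / (fp + tn)

# AUC computed by scikit-learn from the hard labels ...
auc_from_labels = roc_auc_score(y_true, y_hat)
# ... equals the area under the three-point curve (0,0) -> (fpr,tpr) -> (1,1).
single_point_auc = (tpr + 1 - fpr) / 2
print(auc_from_labels, single_point_auc)
```

Passing the raw softmax scores instead of the argmax labels lets the curve sweep over many thresholds, which is usually what a ROC analysis is meant to show.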
Plot the ROC curve for each individual class.
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import roc_curve, auc, roc_auc_score
target = ['airplane', 'automobile', 'bird', 'cat', 'deer',
          'dog', 'frog', 'horse', 'ship', 'truck']
# set plot figure size
fig, c_ax = plt.subplots(1, 1, figsize=(12, 8))
# function for scoring roc auc score for multi-class
def multiclass_roc_auc_score(y_test, y_pred, average="macro"):
    lb = LabelBinarizer()
    lb.fit(y_test)
    y_test = lb.transform(y_test)
    y_pred = lb.transform(y_pred)
    for (idx, c_label) in enumerate(target):
        fpr, tpr, thresholds = roc_curve(y_test[:, idx].astype(int), y_pred[:, idx])
        c_ax.plot(fpr, tpr, label='%s (AUC:%0.2f)' % (c_label, auc(fpr, tpr)))
    c_ax.plot(fpr, fpr, 'b-', label='Random Guessing')
    return roc_auc_score(y_test, y_pred, average=average)
print('ROC AUC score:', multiclass_roc_auc_score(ytrain, ypred))
c_ax.legend()
c_ax.set_xlabel('False Positive Rate')
c_ax.set_ylabel('True Positive Rate')
plt.show()
# ROC AUC score: 0.6868888888888889
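For the 5-class model in the question, the same kind of score can also be computed straight from the softmax probabilities (the output of model.predict, before argmax) with scikit-learn's roc_auc_score in one-vs-rest mode. This is a minimal sketch on synthetic data: `y_true` and `probs` are placeholders standing in for `y_test` and `Pred`, and the numbers it produces say nothing about the real model.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
n_samples, n_classes = 500, 5

# Synthetic integer labels, standing in for y_test.
y_true = rng.integers(0, n_classes, size=n_samples)

# Synthetic softmax-like scores, standing in for model.predict(x_test):
# random logits with a boost on the true class, then a softmax.
logits = rng.normal(size=(n_samples, n_classes))
logits[np.arange(n_samples), y_true] += 1.0
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# Macro-averaged one-vs-rest ROC AUC from probabilities.
macro_auc = roc_auc_score(y_true, probs, multi_class="ovr",
                          average="macro", labels=list(range(n_classes)))

# The same quantity by hand: one binary AUC per class, then the mean.
per_class_auc = [roc_auc_score((y_true == k).astype(int), probs[:, k])
                 for k in range(n_classes)]
print("macro OvR AUC:", macro_auc)
print("per-class AUC:", np.round(per_class_auc, 3))
```

The per-class values are useful on an imbalanced problem like the one in the question, because a macro average can hide a class that the model ranks poorly.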