keras - 在 keras CNN 中获得精确度、召回率、灵敏度和特异性
问题描述
我创建了一个对图像进行二进制分类的 CNN。CNN 如下图所示:
def neural_network():
classifier = Sequential()
# Adding a first convolutional layer
classifier.add(Convolution2D(48, 3, input_shape = (320, 320, 3), activation = 'relu'))
classifier.add(MaxPooling2D())
# Adding a second convolutional layer
classifier.add(Convolution2D(48, 3, activation = 'relu'))
classifier.add(MaxPooling2D())
#Flattening
classifier.add(Flatten())
#Full connected
classifier.add(Dense(256, activation = 'relu'))
#Full connected
classifier.add(Dense(1, activation = 'sigmoid'))
# Compiling the CNN
classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
classifier.summary()
train_datagen = ImageDataGenerator(rescale = 1./255,
horizontal_flip = True,
vertical_flip=True,
brightness_range=[0.5, 1.5])
test_datagen = ImageDataGenerator(rescale = 1./255)
training_set = train_datagen.flow_from_directory('/content/drive/My Drive/data_sep/train',
target_size = (320, 320),
batch_size = 32,
class_mode = 'binary')
test_set = test_datagen.flow_from_directory('/content/drive/My Drive/data_sep/validate',
target_size = (320, 320),
batch_size = 32,
class_mode = 'binary')
es = EarlyStopping(
monitor="val_accuracy",
patience=15,
mode="max",
baseline=None,
restore_best_weights=True,
)
filepath = "/content/drive/My Drive/data_sep/weightsbestval.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]
history = classifier.fit(training_set,
epochs = 50,
validation_data = test_set,
callbacks= callbacks_list
)
best_score = max(history.history['val_accuracy'])
return best_score
文件夹中的图像按以下方式组织:
-train
-healthy
-patient
-validation
-healthy
-patient
有没有办法从这段代码中计算指标精度、召回率、灵敏度和特异性,或者至少是真阳性、真阴性、假阳性和假阴性?
解决方案
from sklearn.metrics import classification_report
test_set = test_datagen.flow_from_directory('/content/drive/My Drive/data_sep/validate',
target_size = (320, 320),
batch_size = 32,
class_mode = 'binary')
predictions = model.predict_generator(
test_set,
steps = np.math.ceil(test_set.samples / test_set.batch_size),
)
predicted_classes = np.argmax(predictions, axis=1)
true_classes = test_set.classes
class_labels = list(test_set.class_indices.keys())
report = classification_report(true_classes, predicted_classes, target_names=class_labels)
accuracy = metrics.accuracy_score(true_classes, predicted_classes)
&如果你这样做print(report)
,它会打印一切
如果你的整个数据文件不能被你的批量大小整除,那么使用
test_set = test_datagen.flow_from_directory('/content/drive/My Drive/data_sep/validate',
target_size = (320, 320),
batch_size = 1,
class_mode = 'binary')
推荐阅读
- php - 如何在 php 中获取 Sql Not Null 字段?
- vb.net - 如何在类中添加点击事件 vb.net
- asp.net-mvc - 如何在部分视图asp.net mvc中编写客户端自定义java脚本进行验证
- ajax - 更改选择框值不能应用于使用 ajax 和 jquery 的标签
- jquery - 如何在图像的宽度上制作花式框标题?
- python - 如何使用 python 检索令牌
- go - 在括号中初始化 Go 结构有什么作用?
- java - 如何从 Spring Boot maven 项目中删除或隐藏目标文件?
- c# - CardView中的Xamarin Recyclerview按钮,无法获取位置
- maven - 如何将jmeter聚合图保存为maven目标文件夹中的文件夹中的png