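python - How to use KFold cross-validation output as CNN input for image processing?
Problem description
I am trying to classify images with a convolutional neural network (CNN), and I want to use KFold cross-validation to train and test on the data. I am new to this and don't really understand how to do it.
I have tried KFold cross-validation and a CNN in separate pieces of code, but I don't know how to combine them.
For the KFold example I use iris_data.csv, which has 3 classes, as input.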
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.preprocessing import MinMaxScaler

dataset = pd.read_csv('iris_data.csv')
# Columns 0-3 hold the four features, column 4 the class label
x = dataset.iloc[:, 0:4]
# Convert to a NumPy array so the fold indices select rows by position
y = dataset.iloc[:, 4].to_numpy()
scaler = MinMaxScaler(feature_range=(0, 1))
x = scaler.fit_transform(x)
cv = KFold(n_splits=10, shuffle=False)
for train_index, test_index in cv.split(x):
    print("Train Index: ", train_index, "\n")
    print("Test Index: ", test_index)
    x_train, x_test = x[train_index], x[test_index]
    y_train, y_test = y[train_index], y[test_index]
Here is the CNN code example.
import numpy as np
import tensorflow as tf
from keras.models import Model
from keras.layers import Input, Activation, Dense, Conv2D, MaxPooling2D, Flatten
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.callbacks import TensorBoard
# Images Dimensions
img_width, img_height = 200, 200
# Data Path
train_data_dir = 'data/train'
validation_data_dir = 'data/validation'
# Parameters
nb_train_samples = 100
nb_validation_samples = 50
epochs = 50
batch_size = 10
# TensorBoard Callbacks
callbacks = TensorBoard(log_dir='./Graph')
# Training Data Augmentation
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
# Rescale Testing Data
test_datagen = ImageDataGenerator(rescale=1. / 255)
# Train Data Generator
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')
# Testing Data Generator
validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')
# Feature Extraction Layer KorNet
inputs = Input(shape=(img_width, img_height, 3))
conv_layer = Conv2D(16, (5, 5), strides=(3,3), activation='relu')(inputs)
conv_layer = MaxPooling2D((2, 2))(conv_layer)
conv_layer = Conv2D(32, (5, 5), strides=(3,3), activation='relu')(conv_layer)
conv_layer = MaxPooling2D((2, 2))(conv_layer)
# Flatten Layer
flatten = Flatten()(conv_layer)
# Fully Connected Layer
fc_layer = Dense(32, activation='relu')(flatten)
outputs = Dense(3, activation='softmax')(fc_layer)
model = Model(inputs=inputs, outputs=outputs)
# Adam Optimizer and Cross Entropy Loss
adam = Adam(learning_rate=0.0001)  # 'lr' is the deprecated name of this argument
model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
# Print Model Summary
model.summary()  # summary() prints the table itself and returns None
# fit() accepts generators in recent Keras versions; fit_generator() is deprecated
model.fit(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size,
    callbacks=[callbacks])
model.save('./models/model.h5')
model.save_weights('./models/weights.h5')
I would like to use the splits produced by KFold cross-validation as the training and test data for the CNN.
Solution
Do something like this:
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import KFold
import numpy

seed = 7  # arbitrary fixed value so the shuffling is reproducible
# assumes iris_data.csv is purely numeric (classes encoded as 0, 1, 2) with no header row
dataset = numpy.loadtxt("iris_data.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:, 0:4]
Y = dataset[:, 4]
# define 10-fold cross validation test harness
kfold = KFold(n_splits=10, shuffle=True, random_state=seed)
cvscores = []
for train, test in kfold.split(X, Y):
    # create a fresh model for every fold
    model = Sequential()
    model.add(Dense(12, input_dim=X.shape[1], activation='relu'))
    # ... (remaining layers elided in the original answer)
    # Compile model; the sparse loss matches integer class labels
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    # Fit the model on this fold's training rows
    model.fit(X[train], Y[train], epochs=150, batch_size=10, verbose=0)
    # evaluate the model on this fold's held-out rows
    scores = model.evaluate(X[test], Y[test], verbose=0)
    print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
    cvscores.append(scores[1] * 100)
# mean and standard deviation of the accuracy across folds
print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))
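A note on shuffle=True in the harness above: the loop in the question used shuffle=False. If the rows of iris_data.csv are grouped by class, as in the classic Iris file (an assumption here, since the file itself is not shown), unshuffled folds end up testing on only one or two classes at a time. The sketch below uses synthetic labels as a stand-in for the real file to illustrate this; the placeholder features and the random_state value are arbitrary.

```python
import numpy as np
from sklearn.model_selection import KFold

# Assumption: 150 rows sorted by class, 50 per class, like the classic Iris CSV.
y = np.repeat([0, 1, 2], 50)
X = np.zeros((len(y), 4))  # placeholder features; only the row order matters here


def classes_per_test_fold(cv):
    """Return the set of class labels present in each test fold."""
    return [set(y[test].tolist()) for _, test in cv.split(X)]


unshuffled = classes_per_test_fold(KFold(n_splits=10, shuffle=False))
shuffled = classes_per_test_fold(KFold(n_splits=10, shuffle=True, random_state=7))

print("shuffle=False:", unshuffled)
print("shuffle=True: ", shuffled)
```

Without shuffling, each test fold is a contiguous block of 15 rows, so the first fold contains only class 0 and no fold can span all three classes. StratifiedKFold is another option if every fold should keep the class proportions.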
See https://machinelearningmastery.com/evaluate-performance-deep-learning-models-keras/
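For the image case, one possible way to apply the same idea is to split a table of image file paths and labels instead of the pixel arrays, then build the data generators from each fold's rows. The following is only a sketch under stated assumptions: all images are pooled in one place rather than pre-split into data/train and data/validation, the file names are hypothetical, make_image_folds is an illustrative helper and not part of any library, and StratifiedKFold is swapped in for KFold so that each fold keeps the class balance.

```python
import pandas as pd
from sklearn.model_selection import StratifiedKFold


def make_image_folds(filenames, labels, n_splits=5, seed=42):
    """Return a list of (train_df, val_df) pairs, one pair per fold."""
    df = pd.DataFrame({"filename": filenames, "class": labels})
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    folds = []
    for train_idx, val_idx in skf.split(df["filename"], df["class"]):
        train_df = df.iloc[train_idx].reset_index(drop=True)
        val_df = df.iloc[val_idx].reset_index(drop=True)
        folds.append((train_df, val_df))
    return folds


# Hypothetical file list: 3 classes with 10 images each.
filenames = [f"class{c}/img_{i:03d}.jpg" for c in range(3) for i in range(10)]
labels = [f"class{c}" for c in range(3) for _ in range(10)]

folds = make_image_folds(filenames, labels)
for k, (train_df, val_df) in enumerate(folds):
    print(f"fold {k}: {len(train_df)} training images, {len(val_df)} validation images")
```

Each train_df and val_df could then be passed to ImageDataGenerator.flow_from_dataframe (with x_col='filename', y_col='class', class_mode='categorical') in place of flow_from_directory. The CNN should be rebuilt and recompiled inside the loop, as in the harness above, so that no weights carry over from one fold to the next.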