keras - 使用 Keras 过拟合的 InceptionV3 迁移学习太快了
问题描述
我在 Keras 上使用预训练的 InceptionV3 重新训练模型以进行二进制图像分类(数据标记为 0 和 1)。
我用从未见过的数据在我的 k 折验证中达到了大约 65% 的准确率,但问题是模型很快就会过拟合。我需要提高这个平均准确率,我想这与这个过度拟合问题有关。
这是代码。数据集和标签变量是 Numpy 数组。
dataset = joblib.load(path_to_dataset)
labels = joblib.load(path_to_labels)
le = LabelEncoder()
labels = le.fit_transform(labels)
labels = to_categorical(labels, 2)
X_train, X_test, y_train, y_test = sk.train_test_split(dataset, labels, test_size=0.2)
X_train, X_val, y_train, y_val = sk.train_test_split(X_train, y_train, test_size=0.25) # 0.25 x 0.8 = 0.2
X_train = np.array(X_train)
y_train = np.array(y_train)
X_val = np.array(X_val)
y_val = np.array(y_val)
X_test = np.array(X_test)
y_test = np.array(y_test)
aug = ImageDataGenerator(
rotation_range=20,
zoom_range=0.15,
horizontal_flip=True,
fill_mode="nearest")
pre_trained_model = InceptionV3(input_shape = (299, 299, 3),
include_top = False,
weights = 'imagenet')
for layer in pre_trained_model.layers:
layer.trainable = False
x = layers.Flatten()(pre_trained_model.output)
x = layers.Dense(1024, activation = 'relu')(x)
x = layers.Dropout(0.2)(x)
x = layers.Dense(2, activation = 'softmax')(x) #already tried with sigmoid activation, same behavior
model = Model(pre_trained_model.input, x)
model.compile(optimizer = RMSprop(lr = 0.0001),
loss = 'binary_crossentropy',
metrics = ['accuracy']) #Already tried with Adam optimizer, same behavior
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=100)
mc = ModelCheckpoint('best_model_inception_rmsprop.h5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
history = model.fit(x=aug.flow(X_train, y_train, batch_size=32),
validation_data = (X_val, y_val),
epochs = 100,
callbacks=[es, mc])
训练数据集有 2181 张图像,验证有 727 张图像。
有什么不对劲,但我不知道是什么...
有什么想法可以改善它吗?
解决方案
避免过度拟合的一种方法是使用大量数据。发生过度拟合的主要原因是你有一个小数据集,你试图从中学习。该算法将更好地控制这个小数据集,并确保它完全满足所有数据点。但是如果你有大量的数据点,那么算法就被迫进行泛化,并提出一个适合大多数点的好模型。建议:
- 使用大量数据。
- 如果您有少量数据样本,请使用较少深度的网络。
- 如果 2nd 满足,那么不要使用大量的 epochs - 使用许多 epochs 引导有点迫使你的模型学习它,你的模型会很好地学习它,但不能泛化。
推荐阅读
- python - Python Pygame 蛇游戏项目混乱
- c# - C# Mono.Cecil 注入的 IL 代码未执行
- ansible - ansible 在 src 和 dest 之间的同一主机上同步
- javascript - 如何从生成的模式中获取简单的 graphql 变异查询?
- python - Python OpenCV videocapture 不从源捕获视频
- dataframe - Julia 日期作为列名
- javascript - Sapper firebase 托管损坏的 CSS 和其他资产链接
- linux - 基于 bash 脚本中的浮点矩阵的循环 (Linux/Ubuntu)
- python - 将字符串变量分解为多列
- elasticsearch - 如何使用 Elastic 对嵌套对象进行子聚合?