python - Deep learning: good results on the training set but poor results on the validation set
Problem description
I am facing a problem and I have difficulty understanding why I get this behaviour.
I am trying to use a pre-trained ResNet50 (Keras) model for binary image classification, and I also built a simple CNN. I have about 8k balanced RGB images of size 200x200, which I divided into three subsets (train 70%, validation 15%, test 15%).
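A minimal sketch of a stratified 70/15/15 split, assuming the labels are available as a 1-D array of class ids. `split_indices` is a hypothetical helper; the question does not say how the split was actually made. Splitting each class separately keeps every subset as balanced as the full set.

```python
import numpy as np


def split_indices(labels, fractions=(0.70, 0.15, 0.15), seed=0):
    """Return (train, val, test) index arrays, split per class (stratified)."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    parts = ([], [], [])
    for cls in np.unique(labels):
        # Shuffle the indices of this class, then cut them into three pieces
        idx = rng.permutation(np.flatnonzero(labels == cls))
        n_train = int(round(fractions[0] * len(idx)))
        n_val = int(round(fractions[1] * len(idx)))
        parts[0].append(idx[:n_train])
        parts[1].append(idx[n_train:n_train + n_val])
        parts[2].append(idx[n_train + n_val:])
    # Merge the classes and shuffle so batches are not sorted by label
    return tuple(rng.permutation(np.concatenate(p)) for p in parts)
```

For example, `split_indices(y)` with 4000 images per class yields 5600 training, 1200 validation and 1200 test indices, each subset split evenly between the two classes.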
I built a generator based on keras.utils.Sequence to feed data to my models.
The problem is that my models tend to learn on the training set, but I get poor results on the validation set, both with the pre-trained ResNet50 and with the simple CNN. I tried several things to solve this, with no improvement at all:
- With and without data augmentation (rotation) on the training set
- Images normalised to [0, 1]
- With and without regularizers
- Different learning rates
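Scaling to [0, 1] can use a fixed divisor, or, as the generator's `rgb_processing` method does, each image's own maximum pixel value. The two differ whenever an image's brightest pixel is below 255. The small NumPy comparison below uses hypothetical helper names and a synthetic image.

```python
import numpy as np


def scale_fixed(rgb):
    """Map 8-bit pixel values to [0, 1] with a constant divisor."""
    return rgb.astype(np.float32) / 255.0


def scale_per_image(rgb):
    """Map to [0, 1] by dividing by this image's own maximum (as rgb_processing does)."""
    rgb = rgb.astype(np.float32)
    return rgb / rgb.max()


# Hypothetical uniformly dark image: every pixel value is 100
dark = np.full((2, 2, 3), 100, dtype=np.uint8)
print(scale_fixed(dark).max(), scale_per_image(dark).max())
```

With per-image scaling the dark image is stretched to a maximum of 1.0, while fixed scaling keeps it at about 0.39. Per-image scaling also divides by zero for an all-black image.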
Here is an example of the results I obtained:
Epoch 1/200
716/716 [==============================] - 320s 447ms/step - loss: 8.6096 - acc: 0.4728 - val_loss: 8.6140 - val_acc: 0.5335
Epoch 00001: val_loss improved from inf to 8.61396, saving model to ../models_saved/resnet_adam_best.h5
Epoch 2/200
716/716 [==============================] - 287s 401ms/step - loss: 8.1217 - acc: 0.5906 - val_loss: 10.9314 - val_acc: 0.4632
Epoch 00002: val_loss did not improve from 8.61396
Epoch 3/200
716/716 [==============================] - 249s 348ms/step - loss: 7.5357 - acc: 0.6695 - val_loss: 11.1432 - val_acc: 0.4657
Epoch 00003: val_loss did not improve from 8.61396
Epoch 4/200
716/716 [==============================] - 284s 397ms/step - loss: 7.5092 - acc: 0.6828 - val_loss: 10.0665 - val_acc: 0.5351
Epoch 00004: val_loss did not improve from 8.61396
Epoch 5/200
716/716 [==============================] - 261s 365ms/step - loss: 7.0679 - acc: 0.7102 - val_loss: 4.2205 - val_acc: 0.5351
Epoch 00005: val_loss improved from 8.61396 to 4.22050, saving model to ../models_saved/resnet_adam_best.h5
Epoch 6/200
716/716 [==============================] - 285s 398ms/step - loss: 6.9945 - acc: 0.7161 - val_loss: 10.2276 - val_acc: 0.5335
....
These are the classes I use to load data and build my models:
import random

import numpy as np
import keras
from keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from keras.models import Model, Sequential
from keras.preprocessing.image import img_to_array, load_img, random_rotation
from keras.utils import to_categorical


class DataGenerator(keras.utils.Sequence):
    def __init__(self, inputs,
                 labels, img_size,
                 input_shape,
                 batch_size, num_classes,
                 validation=False):
        self.inputs = inputs
        self.labels = labels
        self.img_size = img_size
        self.input_shape = input_shape
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.validation = validation
        self.indexes = np.arange(len(self.inputs))
        self.inc = 0

    def __getitem__(self, index):
        """Generate one batch of data

        Parameters
        ----------
        index : the index of the batch to generate

        Returns
        -------
        out : a tuple (inputs, labels)
        """
        batch_inputs = np.zeros((self.batch_size, *self.input_shape))
        batch_labels = np.zeros((self.batch_size, self.num_classes))
        # Generate data
        for i in range(self.batch_size):
            if self.validation:
                # Walk through the validation set in order with an internal counter
                index = self.indexes[self.inc]
                self.inc += 1
                if self.inc == len(self.inputs):
                    self.inc = 0
            else:
                # Choose a random training sample (with replacement)
                index = random.randint(0, len(self.inputs) - 1)
            batch_inputs[i] = self.rgb_processing(self.inputs[index])
            batch_labels[i] = to_categorical(self.labels[index], num_classes=self.num_classes)
        return batch_inputs, batch_labels

    def __len__(self):
        """Denotes the number of batches per epoch

        Returns
        -------
        out : number of batches per epoch
        """
        return int(np.floor(len(self.inputs) / self.batch_size))

    def rgb_processing(self, path):
        # load_img returns a PIL image; img_to_array turns it into an (H, W, 3) float array
        img = load_img(path)
        rgb = img_to_array(img)
        if not self.validation:
            if random.choice([True, False]):
                # random_rotation needs a range in degrees (20 is illustrative)
                # and channels-last axes for (H, W, 3) arrays
                rgb = random_rotation(rgb, 20, row_axis=0, col_axis=1, channel_axis=2)
        # Scale to [0, 1] by this image's maximum pixel value
        return rgb / np.max(rgb)


class Models:
    def __init__(self, input_shape, classes):
        self.input_shape = input_shape
        self.classes = classes

    def simpleCNN(self, optimizer):
        model = Sequential()
        model.add(Conv2D(32, kernel_size=(3, 3),
                         activation='relu',
                         input_shape=self.input_shape))
        model.add(Conv2D(64, (3, 3), activation='relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(128, activation='relu'))
        model.add(Dropout(0.5))
        model.add(Dense(len(self.classes), activation='softmax'))
        model.compile(loss=keras.losses.binary_crossentropy,
                      optimizer=optimizer,
                      metrics=['accuracy'])
        return model

    def resnet50(self, optimizer):
        model = keras.applications.resnet50.ResNet50(include_top=False,
                                                     input_shape=self.input_shape,
                                                     weights='imagenet')
        model.summary()
        # Popping from model.layers does not change model.output, so the
        # Flatten below is still applied to the network's original output
        model.layers.pop()
        model.summary()
        # Freeze every pre-trained layer
        for layer in model.layers:
            layer.trainable = False
        output = Flatten()(model.output)
        # I also tried adding dropout layers here with batch normalization, but it did not change the results
        output = Dense(len(self.classes), activation='softmax')(output)
        finetuned_model = Model(inputs=model.input,
                                outputs=output)
        finetuned_model.compile(optimizer=optimizer,
                                loss=keras.losses.binary_crossentropy,
                                metrics=['accuracy'])
        return finetuned_model
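In the validation branch of `__getitem__`, batches come from the internal counter `self.inc` rather than from the `index` argument Keras passes in, and training batches are drawn with replacement by `random.randint`. A common alternative derives each batch from `index` alone and reshuffles in `on_epoch_end` (a real `keras.utils.Sequence` hook). The class below is a hypothetical, Keras-free sketch of that pattern on plain NumPy arrays; in real code it would subclass `keras.utils.Sequence` and load images from paths instead of indexing arrays.

```python
import numpy as np


class IndexedBatches:
    """Sequence-style batching driven only by the `index` argument."""

    def __init__(self, inputs, labels, batch_size, shuffle=False, seed=0):
        self.inputs = np.asarray(inputs)
        self.labels = np.asarray(labels)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.order = np.arange(len(self.inputs))
        self.on_epoch_end()

    def __len__(self):
        # Same rule as the original __len__: drop the last partial batch
        return len(self.inputs) // self.batch_size

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        # Batch `index` always maps to the same slice of the current order
        ids = self.order[index * self.batch_size:(index + 1) * self.batch_size]
        return self.inputs[ids], self.labels[ids]

    def on_epoch_end(self):
        # Reshuffle the training order each epoch; validation keeps a fixed order
        if self.shuffle:
            self.rng.shuffle(self.order)
```

A validation generator would use `shuffle=False`, so repeated evaluations see the same batches; a training generator would use `shuffle=True`, so each epoch visits every sample at most once.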
This is how these functions are called:
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.optimizers import Adam

train_batches = DataGenerator(inputs=train.X.values,
                              labels=train.y.values,
                              img_size=img_size,
                              input_shape=input_shape,
                              batch_size=batch_size,
                              num_classes=len(CLASSES))
validate_batches = DataGenerator(inputs=validate.X.values,
                                 labels=validate.y.values,
                                 img_size=img_size,
                                 input_shape=input_shape,
                                 batch_size=batch_size,
                                 num_classes=len(CLASSES),
                                 validation=True)
if model_name == "cnn":
    model = models.simpleCNN(optimizer=Adam(lr=0.0001))
elif model_name == "resnet":
    model = models.resnet50(optimizer=Adam(lr=0.0001))
early_stopping = EarlyStopping(patience=15)
checkpointer = ModelCheckpoint(output_name + '_best.h5', verbose=1, save_best_only=True)
history = model.fit_generator(train_batches, steps_per_epoch=num_train_steps, epochs=epochs,
                              callbacks=[early_stopping, checkpointer], validation_data=validate_batches,
                              validation_steps=num_valid_steps)
Solution
I finally found the main cause of this over-fitting. Since I was using a pre-trained model, I had set its layers as non-trainable. I tried making them trainable instead, and that seems to solve the problem. These are the lines that froze the layers:
for layer in model.layers:
    layer.trainable = False
My hypothesis is that my images are too different from the ImageNet data the model was pre-trained on.
I also added some dropout and batch normalization layers at the end of the ResNet model.
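A minimal sketch of the model described in this solution, written against `tf.keras` (an assumption; the question's code imports standalone Keras). The ResNet50 layers are left trainable, and batch normalization and dropout are added before the softmax layer. The function name `build_finetuned_resnet50` and the dropout rate of 0.5 are illustrative; the learning rate of 1e-4 and the loss are taken from the question's code.

```python
from tensorflow import keras
from tensorflow.keras import layers


def build_finetuned_resnet50(input_shape=(200, 200, 3), num_classes=2,
                             weights="imagenet", dropout_rate=0.5):
    """ResNet50 backbone with every layer trainable, plus a BN/dropout head."""
    base = keras.applications.ResNet50(include_top=False, weights=weights,
                                       input_shape=input_shape)
    # Unfreeze the pre-trained layers instead of setting trainable = False
    for layer in base.layers:
        layer.trainable = True

    # Head: flatten the final feature map, then batch normalization and
    # dropout before the softmax output layer
    x = layers.Flatten()(base.output)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = keras.Model(inputs=base.input, outputs=outputs)
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4),
                  loss="binary_crossentropy",  # same loss as the question's code
                  metrics=["accuracy"])
    return model
```

Passing `weights=None` builds the same architecture without downloading the ImageNet weights, which is handy for a quick shape check before a real training run.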