首页 > 解决方案 > 提高图像分类模型的准确性

问题描述

我正在尝试使用预先训练的 Inception V3 模型对两个不同的类进行图像分类。我有大约 1400 张图像的数据集,这些图像大致平衡。当我运行我的程序时,我得到的结果在前几个时期是关闭的。训练模型时这是正常的吗?

epochs =  175

batch_size = 64

#include_top = false to accomodate new classes 
base_model = keras.applications.InceptionV3(
        weights ='imagenet',
        include_top=False, 
        input_shape = (img_width,img_height,3))

#Classifier Model ontop of Convolutional Model
model_top = keras.models.Sequential()
model_top.add(keras.layers.GlobalAveragePooling2D(input_shape=base_model.output_shape[1:], data_format=None)),
model_top.add(keras.layers.Dense(350,activation='relu'))
model_top.add(keras.layers.Dropout(0.4))
model_top.add(keras.layers.Dense(1,activation = 'sigmoid'))
model = keras.models.Model(inputs = base_model.input, outputs = model_top(base_model.output))

#freeze the convolutional layers of InceptionV3
for layer in model.layers[:30]:
layer.trainable = False

#Compiling model using Adam Optimizer 
model.compile(optimizer = keras.optimizers.Adam(
                    lr=0.000001,
                    beta_1=0.9,
                    beta_2=0.999,
                    epsilon=1e-08),
                    loss='binary_crossentropy',
                    metrics=['accuracy'])

在此处输入图像描述

使用我当前的参数,在对一组分离的图像进行测试时,我只能获得 89% 的准确度和 0.3 的测试损失。我是否需要向我的模型添加更多层来提高这种准确性?

标签: pythonmachine-learningkerasneural-network

解决方案


您的代码有几个问题...

首先,您的构建model_top方式非常非传统(恕我直言,也非常混乱);在这种情况下,文档示例是您最好的朋友。因此,首先将您的model_top部分替换为:

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(350, activation='relu')(x)
x = Dropout(0.4)(x)
predictions = Dense(1, activation='sigmoid')(x)

# this is the model we will train
model = Model(inputs=base_model.input, outputs=predictions)

请注意,我没有更改您选择的参数 - 您当然可以在密集层中尝试更多单元(文档中的示例使用 1024)...

其次,不清楚为什么选择只冻结 InceptionV3 的 30 层,它有不少于 311 层:

len(base_model.layers)
# 311

因此,也将此部分替换为

for layer in base_model.layers:
    layer.trainable = False

第三,你的学习率似乎太小了;Adam 优化器应该使用其默认参数开箱即用地工作得很好,所以我还建议将您的模型简单地编译为

model.compile(optimizer = keras.optimizers.Adam(),
                    loss='binary_crossentropy',
                    metrics=['accuracy'])

推荐阅读