python - 提高图像分类模型的准确性
问题描述
我正在尝试使用预先训练的 Inception V3 模型对两个不同的类进行图像分类。我有大约 1400 张图像的数据集,这些图像大致平衡。当我运行我的程序时,我得到的结果在前几个时期是关闭的。训练模型时这是正常的吗?
epochs = 175
batch_size = 64
#include_top = false to accomodate new classes
base_model = keras.applications.InceptionV3(
weights ='imagenet',
include_top=False,
input_shape = (img_width,img_height,3))
#Classifier Model ontop of Convolutional Model
model_top = keras.models.Sequential()
model_top.add(keras.layers.GlobalAveragePooling2D(input_shape=base_model.output_shape[1:], data_format=None)),
model_top.add(keras.layers.Dense(350,activation='relu'))
model_top.add(keras.layers.Dropout(0.4))
model_top.add(keras.layers.Dense(1,activation = 'sigmoid'))
model = keras.models.Model(inputs = base_model.input, outputs = model_top(base_model.output))
#freeze the convolutional layers of InceptionV3
for layer in model.layers[:30]:
layer.trainable = False
#Compiling model using Adam Optimizer
model.compile(optimizer = keras.optimizers.Adam(
lr=0.000001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-08),
loss='binary_crossentropy',
metrics=['accuracy'])
使用我当前的参数,在对一组分离的图像进行测试时,我只能获得 89% 的准确度和 0.3 的测试损失。我是否需要向我的模型添加更多层来提高这种准确性?
解决方案
您的代码有几个问题...
首先,您的构建model_top
方式非常非传统(恕我直言,也非常混乱);在这种情况下,文档示例是您最好的朋友。因此,首先将您的model_top
部分替换为:
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(350, activation='relu')(x)
x = Dropout(0.4)(x)
predictions = Dense(1, activation='sigmoid')(x)
# this is the model we will train
model = Model(inputs=base_model.input, outputs=predictions)
请注意,我没有更改您选择的参数 - 您当然可以在密集层中尝试更多单元(文档中的示例使用 1024)...
其次,不清楚为什么选择只冻结 InceptionV3 的 30 层,它有不少于 311 层:
len(base_model.layers)
# 311
因此,也将此部分替换为
for layer in base_model.layers:
layer.trainable = False
第三,你的学习率似乎太小了;Adam 优化器应该使用其默认参数开箱即用地工作得很好,所以我还建议将您的模型简单地编译为
model.compile(optimizer = keras.optimizers.Adam(),
loss='binary_crossentropy',
metrics=['accuracy'])
推荐阅读
- flutter - 在 Google Play 管理中心更新应用时如何更改版本号和内部版本号?
- arrays - 如何从数组内部的对象访问键:值对?
- java - 如何修复 checkmarx 错误以清理有效负载
- quarkus - Quarkus 中的 Cookie 管理
- javascript - js模块导入排序和执行问题
- python - 如何在行空间中创建自定义尺寸
- reactjs - 在 react 中调用 useMutation 钩子后重新获取查询不起作用
- python - 我试图获取两列中两个字符串之间的距离并将结果放在下一个第三列,但为什么它对所有人产生相同的结果?
- java - Android 11 之后在内部存储中创建目录
- jquery - 错误使用 removeProp() 时 jquery Migrate 插件没有警告