python - 用于二进制分类的 cnn 模型总是返回 1
问题描述
我为二进制分类创建了一个 CNN 模型。我使用了一个包含 300 张图像的平衡数据库。我知道这是一个小型数据库,但我使用了数据增强。拟合模型后,我在验证集上得到了 86% 的 val_accuracy,但是当我想打印每张图片的概率时,我得到第一类的大多数图片的概率 1,甚至所有概率都 > 0.5,所有概率都为 1来自第二类的图像。
这是我的模型:
model = keras.Sequential([
layers.InputLayer(input_shape=[128, 128, 3]),
preprocessing.Rescaling(scale=1/255),
preprocessing.RandomContrast(factor=0.10),
preprocessing.RandomFlip(mode='horizontal'),
preprocessing.RandomRotation(factor=0.10),
layers.BatchNormalization(renorm=True),
layers.Conv2D(filters=64, kernel_size=3, activation='relu', padding='same'),
layers.MaxPool2D(),
layers.BatchNormalization(renorm=True),
layers.Conv2D(filters=128, kernel_size=3, activation='relu', padding='same'),
layers.MaxPool2D(),
layers.BatchNormalization(renorm=True),
layers.Conv2D(filters=256, kernel_size=3, activation='relu', padding='same'),
layers.Conv2D(filters=256, kernel_size=3, activation='relu', padding='same'),
layers.MaxPool2D(),
layers.BatchNormalization(renorm=True),
layers.Flatten(),
layers.Dense(8, activation='relu'),
layers.Dense(1, activation='sigmoid'),])
编辑:
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss='binary_crossentropy',
metrics=['binary_accuracy'],
)
history = model.fit(
ds_train,
validation_data=ds_valid,
epochs=50,
)
谢谢你。
解决方案
像 vgg16 这样的预训练模型可以很好地完成所有工作,不需要使模型变得非常复杂。所以试试下面的代码:
base_model = keras.applications.VGG16(
weights='imagenet',
input_shape=(128, 128, 3),
include_top=False)
base_model.trainable = True
inputs = keras.Input(shape=(128, 128, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
如果您希望模型快速训练,请将 base_model.trainable 设置为 False,并将其设置为 True 以获得更准确的结果。请注意,我使用GlobalAveragePooling2D层而不是 Flatten,以减少参数数量并取消堆叠特征。
推荐阅读
- c# - 如何在角度弹出模式中找到动态 XPath?
- three.js - Three.js Mixer.time = n 不会在第 n 个搅拌机帧上播放,mixer.time=n/1000 也不会
- android - ViewPager2 单次滑动仅滑动一次
- c# - 使用 EF Core FromSql() 在 DbSet 上进行全文索引搜索
同一张表上的 ValueObject 应该没有 LEFT JOIN - android - 将csv文件导入SQLite并访问Android Studio
- c# - NServiceBus 序列化问题
- wordpress - 多个 Rest API 条目出现但帖子已被删除
- c++ - 如何使用二维数组构建如下坐标系?C++
- python - 是否可以重新训练谷歌的通用句子编码器,以便在编码句子时考虑关键字?
- reactjs - 如何使用 React 创建确认对话框组件