首页 > 解决方案 > 使用 Keras 的 CNN 模型精度差

问题描述

我需要建议。当仅使用 CIFAR10 数据集的一个子集(仅使用 10000 个数据,每类 1000 个)时,使用 Keras 构建 CNN 模型时,我得到了一个非常差的结果(10% 的准确度)。如何提高准确性?我尝试更改/增加纪元,但结果仍然相同。这是我的 CNN 架构:

cnn = models.Sequential()
cnn.add(layers.Conv2D(25, (3, 3), input_shape=(32, 32, 3)))
cnn.add(layers.MaxPooling2D((2, 2)))
cnn.add(layers.Activation('relu'))
cnn.add(layers.Conv2D(50, (3, 3)))
cnn.add(layers.MaxPooling2D((2, 2)))
cnn.add(layers.Activation('relu'))
cnn.add(layers.Conv2D(100, (3, 3)))
cnn.add(layers.MaxPooling2D((2, 2)))
cnn.add(layers.Activation('relu'))
cnn.add(layers.Flatten())
cnn.add(layers.Dense(100))
cnn.add(layers.Activation('relu'))
cnn.add(layers.Dense(10))
cnn.add(layers.Activation('softmax'))

编译和拟合:

EPOCHS = 200
BATCH_SIZE = 10
LEARNING_RATE = 0.1

cnn.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
            loss='binary_crossentropy',
            metrics=['accuracy'])

es = EarlyStopping(monitor='val_loss', mode='min', verbose=1)
mc = ModelCheckpoint(filepath=checkpoint_path, monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)

history_cnn = cnn.fit(train_images, train_labels, epochs=EPOCHS, batch_size=BATCH_SIZE,
                validation_data=(test_images, test_labels),callbacks=[es, mc],verbose=0)

我使用的数据是 CIFAR10,但我每个类只拍摄 1000 张图像,所以总数据只有 10000。我使用规范化来预处理数据。

标签: pythontensorflowmachine-learningkerasconv-neural-network

解决方案


首先,问题是损失。您的数据集是一个多类问题,不是二元问题,也不是多标签问题

如此处所述:

这些类是完全互斥的。汽车和卡车之间没有重叠。“汽车”包括轿车、SUV 之类的东西。“卡车”只包括大卡车。两者都不包括皮卡车。

在这种情况下,建议使用categorical crossentropy. 请记住,如果您的标签是稀疏的(使用 0 到 999 之间的数字编码)而不是一个热编码向量([0, 0, 0 ... 1, 0, 0]),您应该使用sparse categorical crossentropy.

  • 不稀疏(标签编码为向量 [0, 0, 1,....0])

    cnn.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
    
  • 稀疏(标签编码为 (0, ... 999) 中的数字)

    cnn.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
    

此外,学习率非常高(0.1)。例如,我建议您从较低的 (0.001) 开始。

这篇文章也与您的问题有关

编辑:我的错,对于过滤器的数量,这是一种具有越来越多过滤器的通用方法


推荐阅读