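keras - Keras CNN model always returns [0.5 0.5]
Problem description
Can anyone help me with this? My model always returns class 1. The source code is below. I want to classify images (binary), and the model reaches good accuracy. Now I need to test the model on new images: I load the saved model and try to predict the class, but it always returns 0.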
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Conv2D, BatchNormalization, LeakyReLU,
                                     MaxPooling2D, Dropout, Flatten, Dense)

batch_size = 30
epochs = 50
IMG_HEIGHT = 224
IMG_WIDTH = 224

image_gen_train = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.01,
    height_shift_range=0.01,
    rescale=1./255,
    shear_range=0.1,
    fill_mode='nearest',
    validation_split=0.2)

train_data_gen = image_gen_train.flow_from_directory(batch_size=batch_size,
                                                     directory=dataset_dir,
                                                     shuffle=True,
                                                     target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                     subset='training',
                                                     class_mode='binary')  # set as training data

val_data_gen = image_gen_train.flow_from_directory(batch_size=batch_size,
                                                   directory=dataset_dir,
                                                   shuffle=False,
                                                   target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                   subset='validation',
                                                   class_mode='binary')  # set as validation data

sample_training_images, _ = next(train_data_gen)

# Plot images in a grid with 1 row and 4 columns, one image per column.
def plotImages(images_arr):
    fig, axes = plt.subplots(1, 4, figsize=(20, 20))
    axes = axes.flatten()
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.savefig('xray_new.png')
    plt.clf()

plotImages(sample_training_images[:4])
# the model
model = Sequential()
model.add(Conv2D(64, kernel_size=(3, 3), input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), padding='same'))
model.add(BatchNormalization(momentum=0.5, epsilon=1e-5, gamma_initializer="uniform"))
model.add(LeakyReLU(alpha=0.1))
model.add(Conv2D(64, kernel_size=(3, 3), padding='same'))
model.add(BatchNormalization(momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"))
model.add(LeakyReLU(alpha=0.1))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.35))
model.add(Conv2D(128, kernel_size=(3, 3), padding='same'))
model.add(BatchNormalization(momentum=0.2, epsilon=1e-5, gamma_initializer="uniform"))
model.add(LeakyReLU(alpha=0.1))
model.add(BatchNormalization(momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"))
model.add(LeakyReLU(alpha=0.1))
model.add(Conv2D(128, (3, 3), padding='same'))
model.add(BatchNormalization(momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"))
model.add(LeakyReLU(alpha=0.1))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.35))
model.add(Conv2D(256, kernel_size=(3, 3), padding='same'))
model.add(BatchNormalization(momentum=0.2, epsilon=1e-5, gamma_initializer="uniform"))
model.add(LeakyReLU(alpha=0.1))
model.add(Conv2D(256, kernel_size=(3, 3), padding='same'))
model.add(BatchNormalization(momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"))
model.add(LeakyReLU(alpha=0.1))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.35))
model.add(Flatten())
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.1))
model.add(BatchNormalization())
model.add(Dense(1, activation='sigmoid'))

model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])
# model.summary()
model.save("model.h5")

history = model.fit_generator(
    train_data_gen,
    steps_per_epoch=train_data_gen.samples // batch_size,
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=val_data_gen.samples // batch_size,
    verbose=1)
But when I test the model, it always outputs class 1:
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image

filepath = 'model.h5'
model = load_model(filepath, compile=True)

def test(model, image_path):
    test_image = image.load_img(image_path, target_size=(IMG_HEIGHT, IMG_WIDTH))
    test_image = image.img_to_array(test_image)
    test_image = np.expand_dims(test_image, axis=0)
    # predict the result
    prediction = model.predict(test_image)
    print(prediction)
    if prediction[0][0] == 1:
        my = 'Normal'
    else:
        my = 'Asthma'
    print(my)
    prediction = np.argmax(prediction)
    labels = (train_data_gen.class_indices)
    labels = dict((v, k) for k, v in labels.items())
    return labels[prediction]
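I'd really appreciate your help!
Solution
I think you forgot to divide the input image by 255 in the testing part. The training generator uses rescale=1./255, so the model was trained on pixel values in [0, 1], while image.img_to_array returns raw values in [0, 255].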
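A minimal sketch of the corrected test-time steps, assuming the model was trained as in the question (rescale=1./255, a single sigmoid output unit). The helper names preprocess and decode_binary are hypothetical, and the mapping {'Asthma': 0, 'Normal': 1} is an assumption based on flow_from_directory ordering class folders alphanumerically. Besides scaling by 255, it thresholds the sigmoid output at 0.5 rather than comparing it to exactly 1 or calling np.argmax, which on a (1, 1) array always returns 0. It also assumes model.h5 holds trained weights; in the code as posted, model.save runs before fit_generator, so the saved file would contain an untrained model.

```python
import numpy as np

def preprocess(img_array):
    """Scale a raw H x W x 3 array of 0-255 pixels to [0, 1] and add a batch axis.

    Mirrors rescale=1./255 from the training ImageDataGenerator.
    """
    x = np.asarray(img_array, dtype="float32") / 255.0
    return np.expand_dims(x, axis=0)

def decode_binary(prediction, class_indices, threshold=0.5):
    """Turn a sigmoid output of shape (1, 1) into (class_name, probability).

    np.argmax is deliberately not used: with one output unit it is always 0.
    """
    p = float(prediction[0][0])
    idx = 1 if p > threshold else 0
    # Invert {name: index} into {index: name}, as the question's code does.
    idx_to_label = {v: k for k, v in class_indices.items()}
    return idx_to_label[idx], p

# Hypothetical Keras usage (not executed here):
# img = image.load_img(image_path, target_size=(IMG_HEIGHT, IMG_WIDTH))
# pred = model.predict(preprocess(image.img_to_array(img)))
# label, p = decode_binary(pred, train_data_gen.class_indices)
```

Returning the probability alongside the label makes it easy to see whether the model is genuinely confident or sitting near 0.5 for every input.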