keras - 使用重新缩放时 model.predict 的意外输出
问题描述
首先:我知道这篇文章,但它没有提供答案。
我正在像这样构建我的模型:
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout
from keras.preprocessing.image import ImageDataGenerator # for data augmentation
import pandas as pd # to save .csv files
from time import perf_counter # to track runtime
from keras.metrics import TrueNegatives, TruePositives, FalseNegatives, FalsePositives
def build_model(dimension):
model = Sequential()
model.add(Conv2D(32, (11,11), activation='relu',
input_shape=(dimension, dimension, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (5, 5), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten()) # to prepare for dropout
model.add(Dropout(0.2)) # to prevent overfitting
model.add(Dense(256, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop',
loss='binary_crossentropy',
metrics=['accuracy',
TruePositives(),
TrueNegatives(),
FalsePositives(),
FalseNegatives()
]
)
return model
def train_model(epoch, batch_size, run, subrun):
dimension = 200
model = build_model(dimension)
train_datagen = ImageDataGenerator(validation_split=0.2,
# samplewise_std_normalization=True,
rotation_range=40,
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
width_shift_range=0.2,
height_shift_range=0.2,
fill_mode='nearest'
)
training_set = train_datagen.flow_from_directory('train6',
target_size=(dimension, dimension),
color_mode='rgb', # default
class_mode='binary',
batch_size=batch_size,
save_to_dir=None,
interpolation='nearest',
subset='training')
validation_set = train_datagen.flow_from_directory('train6',
target_size=(dimension, dimension),
color_mode='rgb', # default
class_mode='binary',
batch_size=batch_size,
save_to_dir=None,
# if 'str', saves augmented images for visualisation
interpolation='nearest',
subset='validation')
start_time = perf_counter() # start counting
history = model.fit_generator(training_set,
epochs=epoch,
steps_per_epoch=training_set.samples // batch_size,
validation_data=validation_set,
validation_steps=validation_set.samples // batch_size,
verbose=2)
stop_time = perf_counter() # stop counting
# saving trained model & history file
model.save_weights('models/cat_dog_classifier_{0}_{1}.h5'.format(run, subrun)) # save model weights
hist_pd = pd.DataFrame(history.history) # making panda file of history.history
hist_csv_file = 'histories/history_{0}_{1}.csv'.format(run, subrun) # defining name for csv file
with open(hist_csv_file, mode='w') as f: # saving the pd file as csv
hist_pd.to_csv(f)
return stop_time - start_time
我使用以下代码来获取概率:
from build_model import build_model
from keras.preprocessing import image
import numpy as np
run = 'A28'
subrun = 1
dimension = 200
# build model
model = build_model(dimension)
model.load_weights('models/cat_dog_classifier_{0}_{1}.h5'.format(run, subrun))
# Get test image ready
amount_of_images = 10
predictions = np.zeros((amount_of_images, 2))
labels = np.zeros(amount_of_images)
for i in range(amount_of_images):
image_name = 1 + i # choose what image to start from
test_image = image.load_img('test1/{}.jpg'.format(image_name), target_size=(dimension, dimension))
test_image = image.img_to_array(test_image)
test_image = np.expand_dims(test_image, axis=0)
label = model.predict_classes(test_image, batch_size=1)
labels[i] = label
prediction = model.predict(test_image, batch_size=1)
print(prediction)
print(labels)
当我在不使用重新缩放或标准化的情况下训练我的模型时,预测是预期的概率。但是,当我使用其中任何一个时,它只返回 0 和 1(与 predict_classes 相同的标签)。我试图运行上面链接中提供的虚拟代码,它按预期工作;我想这是有道理的,因为当我没有使用重新缩放时脚本也可以正常运行。但是,我真的很想能够使用重新缩放。有谁知道出了什么问题?
解决方案
推荐阅读
- amazon-web-services - 如何从 docker 容器将文件上传到 S3 存储桶?
- javascript - 如何将所有鼠标输入路由到特定的 HTML 元素?
- javascript - 如何通过 ID 查找 HTML 元素是否存在并输入不同的内容
- mysql - 当 LOAD DATA LOCAL INFILE 进入 MySQL 表时,每 10000 条记录失败
- excel - 复制特定行并删除空白
- javascript - 侧边栏切换 - 侧边栏切换时主要内容不会拉伸到全宽
- php - 如何从 laravel redirectResponse 获取重定向路由?
- python - 使用 Python 抓取 CSS 样式及其 HTML 标签
- objective-c - 使用方法 __Block_byref_object_copy_ 到达意外的 objc 断点
- reactjs - 多层 forwardRef 和 useImperativeHandle