python - 错误“IndexError”:如何使用 Keras 中训练好的模型对输入图像进行预测?
问题描述
我训练了一个模型来对来自 9 个类的图像进行分类,并使用 model.save() 保存它。这是我使用的代码:
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.layers import Dense, Dropout
from keras.models import Model
from keras.optimizers import Adam, SGD
from keras.preprocessing.image import ImageDataGenerator, image
from keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
from PIL import ImageFile
# Training script as posted in the question: builds a 9-class classifier on
# top of a frozen ResNet50 base, trains it from directory generators and
# saves a model to disk.
# NOTE(review): Pillow's documented flag is spelled LOAD_TRUNCATED_IMAGES
# (plural) -- confirm that this singular spelling has any effect.
ImageFile.LOAD_TRUNCATED_IMAGE = True
# Constants needed throughout the script
N_CLASSES = 9
EPOCHS = 2
PATIENCE = 5
TRAIN_PATH= '/Datasets/Train/'
VALID_PATH = '/Datasets/Test/'
MODEL_CHECK_WEIGHT_NAME = 'resnet_monki_v1_chk.h5'
# Define the model: the pre-trained ResNet50 weights are frozen and a few
# layers are added on top so the network can be fitted to the custom dataset.
# `model` is the ResNet50 base (no top, average pooling); `custom_resnet`
# is the base plus the new Dense/Dropout head.
K.set_learning_phase(0)
model = ResNet50(input_shape=(224,224,3),include_top=False, weights='imagenet', pooling='avg')
K.set_learning_phase(1)
x = model.output
x = Dense(512, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(512, activation='relu')(x)
x = Dropout(0.5)(x)
output = Dense(N_CLASSES, activation='softmax', name='custom_output')(x)
custom_resnet = Model(inputs=model.input, outputs = output)
# Freeze every layer of the ResNet50 base.
# NOTE(review): the loop body below has lost its indentation in this listing.
for layer in model.layers:
layer.trainable = False
custom_resnet.compile(Adam(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])
custom_resnet.summary()
# 4. Load the dataset; both generators apply ResNet50's preprocess_input.
datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
traingen = datagen.flow_from_directory(TRAIN_PATH, target_size=(224,224), batch_size=32, class_mode='categorical')
validgen = datagen.flow_from_directory(VALID_PATH, target_size=(224,224), batch_size=32, class_mode='categorical', shuffle=False)
# 5. Train the model; ModelCheckpoint saves the best model based on validation accuracy
es_callback = EarlyStopping(monitor='val_acc', patience=PATIENCE, mode='max')
mc_callback = ModelCheckpoint(filepath=MODEL_CHECK_WEIGHT_NAME, monitor='val_acc', save_best_only=True, mode='max')
# NOTE(review): validation_data is traingen while validation_steps is
# len(validgen) -- validgen was presumably intended; confirm.
train_history = custom_resnet.fit_generator(traingen, steps_per_epoch=len(traingen), epochs= EPOCHS, validation_data=traingen, validation_steps=len(validgen), verbose=2, callbacks=[es_callback, mc_callback])
# NOTE(review): this saves `model`, i.e. the ResNet50 base defined above,
# not `custom_resnet`, which is the model that was compiled and trained.
model.save('custom_resnet.h5')
它训练成功。为了在新图像上加载和测试这个模型,我使用了下面的代码:
from keras.models import load_model
import cv2
import numpy as np
# Inference script as posted in the question: loads the saved model, reads
# one image with OpenCV and prints the argmax of the prediction.
# NOTE(review): class_names is defined but never used in this snippet.
class_names = ['A', 'B', 'C', 'D', 'E','F', 'G', 'H', 'R']
model = load_model('custom_resnet.h5')
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
# Read the image, resize it to 224x224 and reshape it into a batch of one.
# NOTE(review): the training generators applied preprocess_input, but no
# equivalent preprocessing is applied to the image here -- confirm.
img = cv2.imread('/path to image/4.jpg')
img = cv2.resize(img,(224,224))
img = np.reshape(img,[1,224,224,3])
# Index of the largest value along the last axis of the model output.
classes = np.argmax(model.predict(img), axis = -1)
print(classes)
它输出:
[1915]
为什么它没有给出实际的类别?为什么索引这么大?我只有 9 个类别!
谢谢
解决方案
您保存的是原始的 ResNet50 基础模型(变量 model),而不是您的自定义模型(变量 custom_resnet)。
您调用的是 model.save('custom_resnet.h5'),
但此处的 model 是基础网络:model = ResNet50(input_shape=(224,224,3),include_top=False, weights='imagenet', pooling='avg')
您需要保存的是 custom_resnet 模型:custom_resnet.save('custom_resnet.h5')
这就是为什么调用 predict 时,您得到的是形状为 (1, 2048) 的特征向量,而不是实际的分类预测;对它取 argmax 得到的是 0 到 2047 之间的特征索引(例如 1915),而不是 0 到 8 的类别索引。
更新代码:
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.layers import Dense, Dropout
from keras.models import Model
from keras.optimizers import Adam, SGD
from keras.preprocessing.image import ImageDataGenerator, image
from keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
from PIL import ImageFile
# Training script: builds a 9-class classifier on top of a frozen, ImageNet
# pre-trained ResNet50 base, trains the new head from directory generators
# and saves the full custom model for later inference.

# Let Pillow load truncated image files instead of raising an error.
# The flag is LOAD_TRUNCATED_IMAGES (plural); assigning to the singular
# name only creates an unused attribute and has no effect.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Constants needed throughout the script.
N_CLASSES = 9
EPOCHS = 2
PATIENCE = 5
TRAIN_PATH = '/Datasets/Train/'
VALID_PATH = '/Datasets/Test/'
MODEL_CHECK_WEIGHT_NAME = 'resnet_monki_v1_chk.h5'

# Define the model. `model` is the ResNet50 base (no classification top,
# global average pooling); `custom_resnet` is that base plus a new
# Dense/Dropout head ending in an N_CLASSES-way softmax.
K.set_learning_phase(0)
model = ResNet50(
    input_shape=(224, 224, 3),
    include_top=False,
    weights='imagenet',
    pooling='avg',
)
K.set_learning_phase(1)

x = model.output
x = Dense(512, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(512, activation='relu')(x)
x = Dropout(0.5)(x)
output = Dense(N_CLASSES, activation='softmax', name='custom_output')(x)
custom_resnet = Model(inputs=model.input, outputs=output)

# Freeze every layer of the pre-trained base so only the new head is
# trained. This must happen before compile() to take effect.
for layer in model.layers:
    layer.trainable = False

custom_resnet.compile(
    Adam(lr=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
custom_resnet.summary()

# Load the dataset. Both generators apply ResNet50's preprocess_input, so
# the same preprocessing has to be applied again at inference time.
datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
traingen = datagen.flow_from_directory(
    TRAIN_PATH,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
)
validgen = datagen.flow_from_directory(
    VALID_PATH,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    shuffle=False,
)

# Train the model. ModelCheckpoint keeps the best model by validation
# accuracy and EarlyStopping stops once it no longer improves.
# NOTE(review): the metric key 'val_acc' depends on the Keras version
# (newer releases report 'val_accuracy') -- confirm against the training log.
es_callback = EarlyStopping(monitor='val_acc', patience=PATIENCE, mode='max')
mc_callback = ModelCheckpoint(
    filepath=MODEL_CHECK_WEIGHT_NAME,
    monitor='val_acc',
    save_best_only=True,
    mode='max',
)
train_history = custom_resnet.fit_generator(
    traingen,
    steps_per_epoch=len(traingen),
    epochs=EPOCHS,
    # Validate on the held-out generator; validating on traingen would make
    # 'val_acc' (and therefore both callbacks) track training data.
    validation_data=validgen,
    validation_steps=len(validgen),
    verbose=2,
    callbacks=[es_callback, mc_callback],
)

# Save the trained custom model (base + head), not the bare ResNet50 base.
custom_resnet.save('custom_resnet.h5')
推理代码:
from keras.models import load_model
import cv2
import numpy as np
# Inference script: loads the saved custom model, preprocesses one image the
# same way the training generators did, and prints the predicted class.

# Imported here because the prediction input must go through the same
# preprocessing function that ImageDataGenerator applied during training.
from keras.applications.resnet50 import preprocess_input

# NOTE(review): assumes this list is in the same order as the class indices
# assigned by flow_from_directory during training (traingen.class_indices)
# -- confirm before trusting the printed label.
class_names = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'R']
image_path = '/path to image/4.jpg'

model = load_model('custom_resnet.h5')
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# cv2.imread returns None instead of raising when the file cannot be read.
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError('Could not read image: ' + image_path)

# OpenCV loads images in BGR channel order, while the training generator
# passed RGB images to preprocess_input; convert so the inputs match.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224))

# Add the batch dimension -> (1, 224, 224, 3), then apply the training-time
# preprocessing on a float copy of the pixels.
batch = np.expand_dims(img.astype('float32'), axis=0)
batch = preprocess_input(batch)

# Index of the highest softmax score for each image in the batch.
classes = np.argmax(model.predict(batch), axis=-1)
print(classes)
print(class_names[classes[0]])
推荐阅读
- visual-studio - ASP.NET Core 应用在 VS Code 而非 VS 2017 上启动
- python - 调用 SubpixelConv2D 函数时出现“ValueError: None values not supported”
- html - 如何使输入搜索栏、麦克风图像和两个小框响应?
- rest - PUT 功能与 REST API 中的 POST 有何不同
- css - 页面的导航栏隐藏在幻灯片动画下。如何防止这种情况?我想在顶部单独显示导航栏,动画照常进行
- c# - 如何增加由 . 分隔的字符串值编号(即 1.2.3 == 1.2.4)
- extjs - extjs 问题将 extraParams 发送到restapi
- python - 如何保存图像阅读和代码训练结果?
- asp.net - 防止数据从数据库重新加载到 from
- javascript - 我正在尝试开玩笑地对查询进行“压力测试”-但我无法绕过自动 5000 毫秒超时