Load models in Keras

Problem description

I use the code below to save and then reload a Keras model that was compiled with the AUC metric, but loading fails. Could you help me solve this problem?

train_datagen = ImageDataGenerator(rescale=1/255)
val_datagen = ImageDataGenerator(rescale=1/255)

train_generator = train_datagen.flow_from_directory(
                        train_dir,
                        target_size=(32, 32),
                        batch_size=10,
                        class_mode='binary')
val_generator = val_datagen.flow_from_directory(
                        val_dir, 
                        target_size=(32, 32),
                        batch_size=10,
                        class_mode='binary')

model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', 
              optimizer='rmsprop', 
              metrics=[keras.metrics.AUC(name='auc')])

history = model.fit_generator(train_generator,
                              steps_per_epoch=1405,
                              epochs=1,
                              validation_data=val_generator,
                              validation_steps=10)

model.save('baseline.h5')

model1 = models.load_model('baseline.h5')

I got this ValueError:

ValueError: Unknown metric function: {'class_name': 'AUC', 'config': {'name': 'auc', 'dtype': 'float32', 'num_thresholds': 200, 'curve': 'ROC', 'summation_method': 'interpolation', 'thresholds': [0.005025125628140704, 0.010050251256281407, 0.01507537688442211, 0.020100502512562814

EDIT: I added the imports below. I have heard about the custom_objects argument of the load_model method, and I tried passing 'custom_object'={'auc':keras.metrics.AUC(name='auc')}, without success.

from keras.layers import Dense, Conv2D, MaxPooling2D, Flatten
from keras import models
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import os
from sklearn import metrics
from tensorflow import keras
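For reference, the argument mentioned in the EDIT is spelled custom_objects, and it is passed as an ordinary keyword argument: writing 'custom_object'=... with a quoted name is a Python syntax error. The sketch below only illustrates the syntax. It assumes tf.keras is used for both saving and loading (the imports above mix standalone keras, via "from keras import models", with tensorflow.keras, since "from tensorflow import keras" rebinds the name keras), uses a hypothetical tiny model on random data, and assumes the dictionary key must match the class_name stored in the file ("AUC" in the error message, not the metric's name "auc"). It is not confirmed to fix the error in the setup above.

```python
import numpy as np
import tensorflow as tf

# Hypothetical minimal model compiled with AUC, trained briefly and saved to HDF5.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(loss='binary_crossentropy', optimizer='rmsprop',
              metrics=[tf.keras.metrics.AUC(name='auc')])
x = np.random.default_rng(0).random((32, 4)).astype('float32')
y = (x[:, :1] > 0.5).astype('float32')
model.fit(x, y, epochs=1, verbose=0)
model.save('baseline.h5')

# custom_objects is an ordinary keyword argument (plural, not a quoted string).
# Assumption: its keys are matched against the class names stored in the file,
# so 'AUC' (the class_name in the error) maps to the metric class.
model1 = tf.keras.models.load_model(
    'baseline.h5',
    custom_objects={'AUC': tf.keras.metrics.AUC},
)
print(type(model1).__name__)
```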

Tags: keras, model, auc

Solution


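Load the model with compile=False, so that Keras restores only the architecture and weights and skips the saved training configuration (loss, optimizer, metrics), which is where the AUC metric fails to deserialize. Then compile the model again yourself. Note that the saved optimizer state is not restored this way, so any further training starts with a fresh optimizer: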

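# compile=False: load architecture and weights only, skipping the saved metrics.
model1 = models.load_model('baseline.h5', compile=False)
# Re-compile with the same settings that were used for training.
model1.compile(loss='binary_crossentropy',
               optimizer='rmsprop',
               metrics=[keras.metrics.AUC(name='auc')])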
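A self-contained round trip of this approach, as a sketch. It assumes tf.keras is used for every import (the question mixes standalone keras and tensorflow.keras), and it replaces the image generators and CNN with a hypothetical tiny binary classifier on random data so it runs without train_dir or val_dir. It then checks that the reloaded and re-compiled model gives the same predictions as the original.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the question's CNN: a tiny binary classifier trained
# on random data, so the save/load round trip runs without image directories.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(loss='binary_crossentropy', optimizer='rmsprop',
              metrics=[tf.keras.metrics.AUC(name='auc')])

rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype('float32')
y = (x[:, :1] > 0.5).astype('float32')  # shape (64, 1), binary labels
model.fit(x, y, epochs=1, batch_size=16, verbose=0)
model.save('baseline.h5')  # HDF5 file, as in the question

# Restore architecture and weights only; the saved loss/optimizer/metrics
# configuration is not used, so the AUC config is never deserialized.
model1 = tf.keras.models.load_model('baseline.h5', compile=False)
model1.compile(loss='binary_crossentropy', optimizer='rmsprop',
               metrics=[tf.keras.metrics.AUC(name='auc')])

# Same weights, so predictions should match the original model.
same = np.allclose(model.predict(x, verbose=0), model1.predict(x, verbose=0))
results = model1.evaluate(x, y, verbose=0, return_dict=True)
print(same, results)
```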
