keras - Load models in Keras
Problem description
I use the code below to save and reload a Keras model that uses a custom metric (AUC), but loading does not work. Could you help me solve this problem?
train_datagen = ImageDataGenerator(rescale=1/255)
val_datagen = ImageDataGenerator(rescale=1/255)
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(32, 32),
    batch_size=10,
    class_mode='binary')
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(32, 32),
    batch_size=10,
    class_mode='binary')
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=[keras.metrics.AUC(name='auc')])
history = model.fit_generator(train_generator,
                              steps_per_epoch=1405,
                              epochs=1,
                              validation_data=val_generator,
                              validation_steps=10)
model.save('baseline.h5')
model1 = models.load_model('baseline.h5')
I got the following ValueError:
ValueError: Unknown metric function: {'class_name': 'AUC', 'config': {'name': 'auc', 'dtype': 'float32', 'num_thresholds': 200, 'curve': 'ROC', 'summation_method': 'interpolation', 'thresholds': [0.005025125628140704, 0.010050251256281407, 0.01507537688442211, 0.020100502512562814
EDIT: I have heard about the custom_objects argument of load_model, but passing custom_objects={'auc': keras.metrics.AUC(name='auc')} did not fix it. Here are the imports I use:
from keras.layers import Dense, Conv2D, MaxPooling2D, Flatten
from keras import models
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import os
from sklearn import metrics
from tensorflow import keras
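Note that these imports mix two packages: Sequential and models come from the standalone keras package, while the final `from tensorflow import keras` makes `keras.metrics.AUC` refer to the tf.keras class. One plausible cause of the error (an assumption, not confirmed in this thread) is that standalone Keras cannot deserialize the tf.keras AUC metric stored in the file. The sketch below keeps the whole round trip inside tf.keras and loads with the default compile=True. It assumes TensorFlow 2.x; the tiny Dense model and the random x/y arrays are hypothetical stand-ins for the question's CNN and image generators.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical tiny model: only the save/load round trip matters here.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(loss="binary_crossentropy",
              optimizer="rmsprop",
              metrics=[tf.keras.metrics.AUC(name="auc")])

# Hypothetical random data with both classes present.
x = np.random.rand(16, 4).astype("float32")
y = np.array([0, 1] * 8, dtype="float32").reshape(-1, 1)
model.fit(x, y, epochs=1, verbose=0)

path = os.path.join(tempfile.mkdtemp(), "baseline.h5")
model.save(path)

# Load with the same package that saved the file (tf.keras, not standalone
# keras), so the AUC metric in the saved training config can be rebuilt.
reloaded = tf.keras.models.load_model(path)
scores = reloaded.evaluate(x, y, verbose=0, return_dict=True)
print(scores)
```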
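Solution
Load the model with compile=False, so that Keras does not try to restore the saved metric, and then compile it again yourself: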
model1 = models.load_model('baseline.h5', compile=False)
model1.compile(loss='binary_crossentropy',
               optimizer='rmsprop',
               metrics=[keras.metrics.AUC()])
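A self-contained version of this fix might look like the sketch below. It assumes TensorFlow 2.x with everything imported from tf.keras (not the standalone keras package), and it replaces the image generators with hypothetical random arrays x and y so it runs without train_dir and val_dir. The helper names build_model and compile_model are illustrative only.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Same architecture as in the question, built with tf.keras only.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(32, 32, 3)),
        tf.keras.layers.Conv2D(32, (3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])


def compile_model(model):
    # Hypothetical helper: the same compile settings as the question.
    model.compile(loss="binary_crossentropy",
                  optimizer="rmsprop",
                  metrics=[tf.keras.metrics.AUC(name="auc")])
    return model


# Hypothetical random data standing in for the image generators.
rng = np.random.default_rng(0)
x = rng.random((20, 32, 32, 3), dtype=np.float32)
y = rng.integers(0, 2, size=(20, 1)).astype("float32")

model = compile_model(build_model())
model.fit(x, y, epochs=1, batch_size=10, verbose=0)

path = os.path.join(tempfile.mkdtemp(), "baseline.h5")
model.save(path)

# compile=False restores only the architecture and weights; the saved
# loss/optimizer/metric configuration is ignored, so AUC is never deserialized.
reloaded = tf.keras.models.load_model(path, compile=False)
compile_model(reloaded)  # supply the metric object directly

scores = reloaded.evaluate(x, y, verbose=0, return_dict=True)
print(scores)
```

Because compile=False skips the saved training configuration, the optimizer state stored in the file is not restored either; recompiling starts a fresh RMSprop optimizer, which only matters if you want to resume training exactly where it stopped.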