python-3.x - 通过 ModelCheckpoint 回调保存整个模型不起作用
问题描述
我试图在每个时代之后使用 ModelCheckpoint 回调保存整个模型。训练后,如果我尝试加载保存的模型并进行评估,则不会加载模型权重。为什么这个 load_model 不能加载模型权重?
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0
# Use smaller subset -- speeds things up
x_train = x_train[:10000]
y_train = y_train[:10000]
x_test = x_test[:1000]
y_test = y_test[:1000]
def get_test_accuracy(model, x_test, y_test):
test_loss, test_acc = model.evaluate(x=x_test, y=y_test, verbose=0)
print('accuracy: {acc:0.3f}'.format(acc=test_acc))
def get_new_model():
model = Sequential([
Conv2D(filters=16, input_shape=(32, 32, 3), kernel_size=(3, 3),
activation='relu', name='conv_1'),
Conv2D(filters=8, kernel_size=(3, 3), activation='relu', name='conv_2'),
MaxPooling2D(pool_size=(4, 4), name='pool_1'),
Flatten(name='flatten'),
Dense(units=32, activation='relu', name='dense_1'),
Dense(units=10, activation='softmax', name='dense_2')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
# Create Tensorflow checkpoint object
checkpoint_path = "model_checkpoints"
checkpoint = ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=False,
save_freq="epoch",
verbose=1)
# Create and fit model with checkpoint
model = get_new_model()
model.fit(x_train,
y_train,
epochs=3,
callbacks=[checkpoint])
# Get the model's test accuracy
get_test_accuracy(model,x_test,y_test)
# Reload model from scratch
model = load_model(checkpoint_path)
get_test_accuracy(model,x_test,y_test)
加载保存的模型 load_model 后的精度与训练模型的精度不同。
解决方案
推荐阅读
- excel - 对工作表进行拼写检查然后在不删除规则的情况下将其锁定的宏
- java - 使用原始 128 ASCII 表使用仿射密码加密/解密字符串
- r - 使用命令行 R 的工作流程?
- mysql - 如何记录谁在我的sql中编辑了一个表
- javascript - 如何在 JavaScript 中从 .csv 文件中读取特殊字符
- scala - 在简单教程示例上使用 sbt run 不执行
- windows-7 - Cuckoo Sandbox 不生成 memory.dmp
- android - Unity 构建和运行不适用于 android 设备
- visual-studio-code - 如何在 VSCode 中选择整行
- r - 如何在R中的表情符号之间拆分字符串