Saving the entire model with the ModelCheckpoint callback does not work

Problem description

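I am trying to save the entire model after every epoch with the ModelCheckpoint callback. After training, when I load the saved model and evaluate it, the model weights do not seem to be loaded. Why does load_model not restore the model weights?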

import tensorflow as tf
print(tf.__version__)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0

# Use smaller subset -- speeds things up
x_train = x_train[:10000]
y_train = y_train[:10000]
x_test = x_test[:1000]
y_test = y_test[:1000]

def get_test_accuracy(model, x_test, y_test):
    test_loss, test_acc = model.evaluate(x=x_test, y=y_test, verbose=0)
    print('accuracy: {acc:0.3f}'.format(acc=test_acc))


def get_new_model():
    model = Sequential([
        Conv2D(filters=16, input_shape=(32, 32, 3), kernel_size=(3, 3), 
               activation='relu', name='conv_1'),
        Conv2D(filters=8, kernel_size=(3, 3), activation='relu', name='conv_2'),
        MaxPooling2D(pool_size=(4, 4), name='pool_1'),
        Flatten(name='flatten'),
        Dense(units=32, activation='relu', name='dense_1'),
        Dense(units=10, activation='softmax', name='dense_2')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model


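# Create a Keras ModelCheckpoint callback that saves the entire model
# (not just the weights) at the end of every epoch. Without an .h5 suffix,
# TF 2.x writes a SavedModel directory, overwritten on each epoch.
checkpoint_path = "model_checkpoints"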

checkpoint = ModelCheckpoint(filepath=checkpoint_path,
                             save_weights_only=False,
                             save_freq="epoch",
                             verbose=1)

# Create and fit model with checkpoint
model = get_new_model()

model.fit(x_train,
          y_train,
          epochs=3,
          callbacks=[checkpoint])

# Get the model's test accuracy

get_test_accuracy(model,x_test,y_test)

# Reload model from scratch
model = load_model(checkpoint_path)
get_test_accuracy(model,x_test,y_test)
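
A more direct way to test the claim that the weights are not restored is to compare the weight arrays themselves instead of the accuracy figure. The sketch below uses a hypothetical helper, `weights_match` (not part of Keras), that checks two lists of arrays, such as those returned by `model.get_weights()`, for equal shapes and values within a tolerance.

```python
import numpy as np


def weights_match(weights_a, weights_b, atol=1e-6):
    """Return True if two lists of weight arrays have the same shapes and values.

    Hypothetical helper for illustration; pass e.g. model.get_weights().
    """
    # Different number of weight tensors means different architectures.
    if len(weights_a) != len(weights_b):
        return False
    # Every tensor must have the same shape and (nearly) the same values.
    return all(
        np.asarray(a).shape == np.asarray(b).shape
        and np.allclose(a, b, atol=atol)
        for a, b in zip(weights_a, weights_b)
    )
```

To use it, keep a reference to the trained model before `model` is reassigned, then call, for example, `weights_match(trained_model.get_weights(), load_model(checkpoint_path).get_weights())`, where `trained_model` is an illustrative name. Because the last checkpoint is written at the end of the final epoch, the two lists should match if saving and reloading worked.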


After reloading the saved model with load_model, its accuracy differs from the accuracy of the trained model.
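
If the reloaded weights turn out to be identical, the difference may come from the metric rather than the weights. One hypothesis, not confirmed in the question, is that the reloaded model resolves the string `'accuracy'` to categorical accuracy instead of sparse categorical accuracy. The NumPy sketch below reimplements both formulas as standalone functions (the names mirror Keras metrics, but these are illustrations, not the Keras implementations) on made-up predictions with integer labels shaped `(N, 1)`, like CIFAR-10's `y_test`. With such labels, categorical accuracy takes the argmax over a single column, which is always 0, so it only counts predictions of class 0.

```python
import numpy as np


def sparse_categorical_accuracy(y_true, y_pred):
    """Integer labels: compare each label with the predicted class index."""
    labels = np.asarray(y_true).reshape(-1)
    return float(np.mean(labels == np.argmax(y_pred, axis=-1)))


def categorical_accuracy(y_true, y_pred):
    """One-hot labels: compare argmax of the label row with argmax of the prediction."""
    return float(np.mean(np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1)))


# Made-up softmax outputs for 4 samples over 3 classes.
y_pred = np.array([[0.8, 0.1, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.1, 0.1, 0.8],
                   [0.2, 0.7, 0.1]])
# Integer labels shaped (N, 1), like CIFAR-10's y_test.
y_true = np.array([[0], [1], [2], [1]])

print(sparse_categorical_accuracy(y_true, y_pred))  # 1.0
print(categorical_accuracy(y_true, y_pred))  # 0.25: only the class-0 prediction "matches"
```

If this is the cause, compiling with `metrics=['sparse_categorical_accuracy']` names the metric explicitly, leaving no string to resolve; treat this as a workaround to try rather than a confirmed fix. Comparing the `test_loss` returned by `model.evaluate` before and after reloading is another quick check, since the loss does not depend on how `'accuracy'` is interpreted.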

Tags: python-3.x, deep-learning, callback, tensorflow2.0, tf.keras
