首页 > 解决方案 > TensorFlow ModelCheckpoint 不保存模型，且重新加载后缺少 val_loss

问题描述

回调确实保存了检查点文件，但没有生成 SavedModel 格式的 saved_model.pb 文件。此外，当我从检查点重新加载模型时，它不会恢复 'val_loss'，而我正是用 ModelCheckpoint 的 "save_best_only" 依据该指标挑选最佳模型的。

我尝试只在最佳迭代时调用 model.save()，但没能成功；使用 ModelCheckpoint 回调会更方便。

以下是相关代码：

# One MSE loss instance, reused for each of the three model outputs.
# (No trailing comma: `MeanSquaredError(),` would make LOSS a 1-tuple.)
LOSS = tf.keras.losses.MeanSquaredError()

# multi output: 3 categories, each a score from 0 to 1
# NOTE(review): ImgToClassSimpleContinuous must be defined or imported before
# this line executes; in this excerpt the class appears further down.
model = ImgToClassSimpleContinuous(img_height, img_width)

checkpoint_filename = "../chkpts/ImgToClassSimpleContinuous/checkpoint_dir"

# Compile first: evaluate() below needs a compiled model.
model.compile(
    optimizer='adam',
    loss=[LOSS, LOSS, LOSS],
    metrics=['mse'])

# Keep only the epoch with the lowest validation loss. With
# save_weights_only=False the full model (not just weights) is saved.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_filename,
                                                 verbose=1, mode='min', monitor="val_loss",
                                                 save_best_only=True, save_weights_only=False)

# Resume from a previous run only if something was saved at this path
# (prefix match covers both a SavedModel directory and weight-checkpoint files).
if tf.io.gfile.glob(checkpoint_filename + "*"):
    model.load_weights(checkpoint_filename)
    # ModelCheckpoint does not persist its best monitored value; a new callback
    # starts from +inf and would overwrite the saved best model after the first
    # epoch. Re-measure the restored model's validation loss ("val_loss" in fit
    # is the "loss" entry of evaluating on the validation data) and seed it.
    restored_metrics = model.evaluate(dataset_validation_batched, verbose=0, return_dict=True)
    cp_callback.best = restored_metrics["loss"]

# NOTE(review): if dataset_to_use is a (presumably already batched)
# tf.data.Dataset, Keras does not use `batch_size` for it -- confirm and drop it.
model.fit(
    dataset_to_use,
    validation_data=dataset_validation_batched,
    epochs=MAX_EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=[cp_callback])

class ImgToClassSimpleContinuous(Model):
  '''Image -> three independent scalar scores, each in [0, 1].

  A shared two-stage convolutional trunk (8 then 16 filters) feeds three
  identical 32-filter convolutional heads. Each head is globally
  average-pooled and ends in a single-unit sigmoid Dense layer, so the model
  returns a list of three outputs, each of shape (batch, 1) with values in
  [0, 1]. Pair each output with a regression loss such as MeanSquaredError
  (one per output), not categorical_crossentropy.

  Args:
    img_height: height in pixels of the single-channel input image.
    img_width: width in pixels of the single-channel input image.
    *args, **kwargs: forwarded unchanged to keras.Model.
  '''
  in_types = [DataType.d]
  out_types = [DataType.tlc, DataType.tls, DataType.tll]

  def __init__(self, img_height, img_width, *args, **kwargs):
    # Forward only the caller's arguments to keras.Model.
    super().__init__(*args, **kwargs)

    initializer = 'he_normal'

    input_shape = (img_height, img_width, 1)
    inputs = tf.keras.Input(shape=input_shape)

    # Shared trunk: two conv stages, each halving (floor) height and width.
    x = self._conv_stage(inputs, 8, initializer)
    x = self._conv_stage(x, 16, initializer)

    # Three independent heads on the shared trunk, built in t, s, l order.
    # Flatten after GlobalAveragePooling2D is a no-op on the already 2-D
    # tensor; it is kept so the model's layer structure stays unchanged
    # (e.g. for previously saved weights).
    head_features = []
    for _ in range(3):
      h = self._conv_stage(x, 32, initializer)
      h = tf.keras.layers.GlobalAveragePooling2D()(h)
      h = layers.Flatten()(h)
      head_features.append(h)

    # One sigmoid unit per head -> three scalar outputs in [0, 1].
    t_out, s_out, l_out = [layers.Dense(1, activation='sigmoid')(h) for h in head_features]

    self.model = tf.keras.Model(inputs, [t_out, s_out, l_out])

    # Writes an architecture diagram of the inner functional model.
    # NOTE(review): "..." looks like a placeholder path -- confirm.
    tf.keras.utils.plot_model(self.model, to_file="...", show_shapes=True)

  @staticmethod
  def _conv_stage(x, filters, initializer):
    '''Apply (Conv2D 3x3 -> PReLU) twice, then 2x2 max-pool and batch norm.'''
    for _ in range(2):
      x = layers.Conv2D(filters, 3, padding='same', kernel_initializer=initializer)(x)
      x = layers.PReLU()(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    return layers.BatchNormalization()(x)

  def call(self, x, training=None):
    '''Run the wrapped functional model, forwarding the `training` flag
    explicitly so BatchNormalization switches between train/inference modes.'''
    return self.model(x, training=training)

标签: tensorflow keras model callback checkpoint

解决方案


推荐阅读