tensorflow - Tensorflow ModelCheckpoint 不保存模型,且重新加载后没有 loss 值
问题描述
回调确实在保存检查点文件,但没有生成 SavedModel 的
model.pb
文件。此外,当我从检查点加载模型时,它不会重新加载我用来判断
"save_best_model"(即代码中的 save_best_only)的 'val_loss'。
我尝试过只在最佳迭代时调用 model.save(),
但无法使其正常工作;而且使用 ModelCheckpoint
回调会更方便。
这是相关代码
# Mean-squared-error loss shared by all three model outputs.
# The original line ended with a trailing comma, which made LOSS a 1-tuple
# rather than a loss object; the comma is removed here.
LOSS = tf.keras.losses.MeanSquaredError()

# Multi-output model: 3 values, each in the range 0 to 1.
model = ImgToClassSimpleContinuous(img_height, img_width)

checkpoint_filename = "../chkpts/ImgToClassSimpleContinuous/checkpoint_dir"
# Restore weights from the same path the checkpoint callback below writes to.
model.load_weights(checkpoint_filename)

# Keep only the checkpoint with the lowest validation loss seen in this run.
# NOTE(review): the callback's notion of "best" starts fresh on every run, so
# the first epoch after a restart always overwrites the checkpoint. Recent
# Keras versions accept `initial_value_threshold=` to seed it (e.g. with the
# val_loss from `model.evaluate`) - confirm the installed version supports it.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filename,
    verbose=1,
    mode='min',
    monitor="val_loss",
    save_best_only=True,
    save_weights_only=False,
)

model.compile(
    optimizer='adam',
    loss=[LOSS, LOSS, LOSS],  # one loss per output head
    metrics=['mse'],
)

model.fit(
    dataset_to_use,
    validation_data=dataset_validation_batched,
    # validation_steps=50,
    epochs=MAX_EPOCHS,
    # NOTE(review): Keras documents that `batch_size` must not be given when
    # the input is a dataset/generator (it is already batched) - verify what
    # `dataset_to_use` is and drop this argument if it is a tf.data.Dataset.
    batch_size=BATCH_SIZE,
    callbacks=[cp_callback],
)
class ImgToClassSimpleContinuous(Model):
    """Convolutional network mapping a 1-channel image to three scalars.

    A shared trunk of two convolutional stages (8 and 16 filters) feeds three
    structurally identical heads (32 filters, global average pooling). Each
    head ends in a single sigmoid unit, so every output lies in (0, 1).

    Because the outputs are independent sigmoid values rather than a softmax
    distribution, a per-output regression-style loss fits this architecture.

    Args:
        img_height: Height of the input image in pixels.
        img_width: Width of the input image in pixels.
        *args: Forwarded to the base ``Model`` constructor.
        **kwargs: Forwarded to the base ``Model`` constructor.
    """

    # Project-level descriptors of the model's input and output kinds.
    # NOTE(review): DataType is defined elsewhere; presumably the three
    # out_types correspond, in order, to the three heads - confirm.
    in_types = [DataType.d]
    out_types = [DataType.tlc, DataType.tls, DataType.tll]

    def __init__(self, img_height, img_width, *args, **kwargs):
        # The previous code passed `ImgToClassSimple` as a positional argument
        # here, which looks like a leftover from the two-argument
        # `super(Class, self)` form; the base constructor does not take it.
        super().__init__(*args, **kwargs)

        initializer = 'he_normal'
        inputs = tf.keras.Input(shape=(img_height, img_width, 1))

        # Shared trunk.
        x = self._conv_block(inputs, 8, initializer)
        x = self._conv_block(x, 16, initializer)

        # Three parallel heads branching off the shared trunk.
        out_t = self._head(x, initializer)
        out_s = self._head(x, initializer)
        out_l = self._head(x, initializer)

        # One sigmoid unit per head: each output is a scalar in (0, 1).
        out_t = layers.Dense(1, activation='sigmoid')(out_t)
        out_s = layers.Dense(1, activation='sigmoid')(out_s)
        out_l = layers.Dense(1, activation='sigmoid')(out_l)

        self.model = tf.keras.Model(inputs, [out_t, out_s, out_l])
        # Side effect: renders the architecture diagram to a file each time
        # the model is constructed.
        tf.keras.utils.plot_model(self.model, to_file="...", show_shapes=True)

    @staticmethod
    def _conv_block(x, filters, initializer):
        """Apply Conv-PReLU-Conv-PReLU-MaxPool(2x2)-BatchNorm to ``x``."""
        for _ in range(2):
            x = layers.Conv2D(
                filters, 3, padding='same', kernel_initializer=initializer
            )(x)
            x = layers.PReLU()(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)
        return layers.BatchNormalization()(x)

    @staticmethod
    def _head(x, initializer):
        """Build one 32-filter head ending in a flat per-image feature vector."""
        h = ImgToClassSimpleContinuous._conv_block(x, 32, initializer)
        h = tf.keras.layers.GlobalAveragePooling2D()(h)
        # Kept from the original graph so the layer structure is unchanged.
        return layers.Flatten()(h)

    def call(self, x):
        """Run the wrapped functional model; returns its three outputs."""
        return self.model(x)
解决方案
推荐阅读
- java - Java - 可以使用 return 语句退出程序吗?
- python - 如何删除机器人直接消息?
- javascript - SignalR Group 消息应用程序,创建组消息框导致刷新
- android - Xamarin.Forms Live Player - 由于 API 版本不支持的设备
- typescript - 打字稿类型的 redux-persist MigrationManifest?
- jquery - 如何在 html 模板中使用 contextPath
- webpack - 非相对导入 - Ionic - Webpack - Sourcemaps 在 Chrome 中的一个文件上偏离了几行
- android - 我无法将活动的数据(字符串)发送到活动中包含的片段
- arrays - 将字符串与字符串数组进行比较并计算 Swift 中的匹配项
- ios - iOS Autolayout:可以将约束优先级设置为 999 以使警告静音吗?