Cannot load model weights in TensorFlow 2

Problem description

I cannot load model weights after saving them in TensorFlow 2.2. The weights appear to be saved correctly (I think), but loading the pre-trained model fails.

My current code is:

segmentor = sequential_model_1()
discriminator = sequential_model_2()

def save_model(ckp_dir):
    # create directory, if it does not exist:
    utils.safe_mkdir(ckp_dir)

    # save weights
    segmentor.save_weights(os.path.join(ckp_dir, 'checkpoint-segmentor'))
    discriminator.save_weights(os.path.join(ckp_dir, 'checkpoint-discriminator'))

def load_pretrained_model(ckp_dir):
    try:
        segmentor.load_weights(os.path.join(ckp_dir, 'checkpoint-segmentor'), skip_mismatch=True)
        discriminator.load_weights(os.path.join(ckp_dir, 'checkpoint-discriminator'), skip_mismatch=True)
        print('Loading pre-trained model from: {0}'.format(ckp_dir))
    except ValueError:
        print('No pre-trained model available.')

Then I have the training loop:

# training loop:
for epoch in range(num_epochs):

    for image, label in dataset:
        train_step()

    # save best model I find during training:
    if this_is_the_best_model_on_validation_set():
        save_model(ckp_dir='logs_dir')

And then, at the end of the training "for loop", I want to load the best model and do a test with it. Hence, I run:

# load saved model and do a test:
load_pretrained_model(ckp_dir='logs_dir')
test()

However, this results in a ValueError. I checked the directory where the weights should be saved, and there they are!

Any idea what is wrong with my code? Am I loading the weights incorrectly?

Thank you!

Tags: tensorflow, save, load

Solution


The problem is that your try-except block is obscuring the real issue. Removing it reveals the actual ValueError:

ValueError: When calling model.load_weights, skip_mismatch can only be set to True when by_name is True.

There are two ways to fix this: either call load_weights with by_name=True, or remove skip_mismatch=True, depending on your needs. Both work for me when testing your code.
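Here is a minimal sketch of the second option (dropping skip_mismatch). It also replaces the broad try/except with an explicit existence check, so a genuine loading error is no longer hidden. It assumes the weights were saved in the TensorFlow checkpoint format, where save_weights writes a '<prefix>.index' file. The function signature with its `model` and `name` parameters is illustrative and not part of the original code. If you prefer the by_name=True route, note that the Keras documentation says only topological loading is supported for weight files in TensorFlow format, so that option may require saving to HDF5 (.h5) instead.

```python
import os


def load_pretrained_model(model, ckp_dir, name):
    """Load weights into `model` if a checkpoint with this prefix exists.

    `model` is any object with a Keras-style load_weights(path) method.
    Returns True if weights were loaded, False if no checkpoint was found.
    """
    prefix = os.path.join(ckp_dir, name)

    # Assumption: weights were saved in TensorFlow checkpoint format, for
    # which save_weights writes '<prefix>.index' next to the data shards.
    if not os.path.exists(prefix + '.index'):
        print('No pre-trained model available at: {0}'.format(prefix))
        return False

    # No skip_mismatch and no blanket try/except: a real loading error
    # propagates with its full message instead of being swallowed.
    model.load_weights(prefix)
    print('Loaded pre-trained weights from: {0}'.format(prefix))
    return True
```

With the models from the question, this would be called as load_pretrained_model(segmentor, 'logs_dir', 'checkpoint-segmentor'), and likewise for discriminator.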

Another consideration is that when you store both the discriminator and segmentor checkpoints in the same log directory, you overwrite the checkpoint file each time. This file contains two strings that give the path to the specific model checkpoint files. Since you save the discriminator second, the file will always point to the discriminator, with no reference to the segmentor. You can avoid this by storing each model in its own subdirectory of the log directory, i.e.

logs_dir/
    + discriminator/
        + checkpoint
        + ...
    + segmentor/
        + checkpoint
        + ...

That said, your code in its current state would still work despite the overwritten checkpoint file.
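The per-model subdirectory layout above could be produced with something like the following sketch. The `models` dictionary argument and the use of os.makedirs in place of utils.safe_mkdir are assumptions made for illustration. The file prefixes follow the 'checkpoint-<name>' pattern from the question.

```python
import os


def save_model(models, ckp_dir):
    """Save each model's weights into its own subdirectory of `ckp_dir`.

    `models` maps a name (e.g. 'segmentor') to an object with a
    Keras-style save_weights(path) method.
    """
    for name, model in models.items():
        # One directory per model, so each gets its own 'checkpoint' file
        model_dir = os.path.join(ckp_dir, name)
        os.makedirs(model_dir, exist_ok=True)

        # Writes e.g. logs_dir/segmentor/checkpoint-segmentor.*
        model.save_weights(os.path.join(model_dir, 'checkpoint-' + name))
```

In the training loop this would be called as save_model({'segmentor': segmentor, 'discriminator': discriminator}, ckp_dir='logs_dir'). The loading paths then need to include the same subdirectory.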

