tensorflow - Cannot load model weights in TensorFlow 2
Problem description
I cannot load model weights after saving them in TensorFlow 2.2. The weights appear to be saved correctly (I think), but I fail to load the pre-trained model.
My current code is:
segmentor = sequential_model_1()
discriminator = sequential_model_2()

def save_model(ckp_dir):
    # create directory, if it does not exist:
    utils.safe_mkdir(ckp_dir)
    # save weights
    segmentor.save_weights(os.path.join(ckp_dir, 'checkpoint-segmentor'))
    discriminator.save_weights(os.path.join(ckp_dir, 'checkpoint-discriminator'))

def load_pretrained_model(ckp_dir):
    try:
        segmentor.load_weights(os.path.join(ckp_dir, 'checkpoint-segmentor'), skip_mismatch=True)
        discriminator.load_weights(os.path.join(ckp_dir, 'checkpoint-discriminator'), skip_mismatch=True)
        print('Loading pre-trained model from: {0}'.format(ckp_dir))
    except ValueError:
        print('No pre-trained model available.')
Then I have the training loop:
# training loop:
for epoch in range(num_epochs):
    for image, label in dataset:
        train_step()
    # save best model I find during training:
    if this_is_the_best_model_on_validation_set():
        save_model(ckp_dir='logs_dir')
And then, at the end of the training "for loop", I want to load the best model and do a test with it. Hence, I run:
# load saved model and do a test:
load_pretrained_model(ckp_dir='logs_dir')
test()
However, this results in a ValueError. I checked the directory where the weights should be saved, and there they are!
Any idea what is wrong with my code? Am I loading the weights incorrectly?
Thank you!
Solution
OK, here is your problem: the try-except block you have is obscuring the real issue. Removing it gives the ValueError:
ValueError: When calling model.load_weights, skip_mismatch can only be set to True when by_name is True.
There are two ways to mitigate this: you can either call load_weights with by_name=True, or remove skip_mismatch=True, depending on your needs. Either case works for me when testing your code.
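As a minimal sketch of the second option (removing skip_mismatch=True), the loader could look roughly like this. Passing the two models in as arguments is my assumption, made only to keep the snippet self-contained; the original code uses module-level variables. The try-except is dropped on purpose so that any error reaches you with its real message. I chose this option rather than by_name=True because, as far as I know, the Keras documentation describes by_name loading for HDF5 weight files, while TensorFlow-format checkpoints are loaded by topology.

```python
import os


def load_pretrained_model(segmentor, discriminator, ckp_dir):
    """Load both models' weights from ckp_dir, letting any error surface.

    Assumption: the models are passed in explicitly instead of being
    module-level globals as in the question.
    """
    # skip_mismatch is dropped: load_weights only accepts it together
    # with by_name=True, which is what raised the ValueError.
    segmentor.load_weights(os.path.join(ckp_dir, 'checkpoint-segmentor'))
    discriminator.load_weights(os.path.join(ckp_dir, 'checkpoint-discriminator'))
    print('Loaded pre-trained model from: {0}'.format(ckp_dir))
```

With the models from the question this would be called as load_pretrained_model(segmentor, discriminator, ckp_dir='logs_dir').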
Another consideration is that when you store both the discriminator and segmentor checkpoints in the log directory, you overwrite the checkpoint file each time. This file contains two strings that give the paths to the specific model checkpoint files. Since you save the discriminator second, the file always ends up pointing at the discriminator, with no reference to the segmentor. You can mitigate this by storing each model in its own subdirectory of the log directory instead, i.e.
logs_dir/
+ discriminator/
+ checkpoint
+ ...
+ segmentor/
+ checkpoint
+ ...
That said, your code in its current state would still work without this change.
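A rough sketch of the subdirectory layout, under a few assumptions of mine: the models are passed in as a dict keyed by name, os.makedirs stands in for the utils.safe_mkdir helper from the question, and the prefix name 'weights' is hypothetical (any prefix would do).

```python
import os


def checkpoint_prefix(ckp_dir, model_name):
    """Return the weights prefix inside a per-model subdirectory.

    'weights' is a hypothetical prefix name chosen for this sketch.
    """
    return os.path.join(ckp_dir, model_name, 'weights')


def save_model(models, ckp_dir):
    """Save each model's weights under ckp_dir/<model_name>/."""
    for name, model in models.items():
        prefix = checkpoint_prefix(ckp_dir, name)
        # Create the per-model subdirectory if it does not exist yet.
        os.makedirs(os.path.dirname(prefix), exist_ok=True)
        model.save_weights(prefix)
```

For example, save_model({'segmentor': segmentor, 'discriminator': discriminator}, 'logs_dir') would give each model its own checkpoint file, and loading would then use checkpoint_prefix('logs_dir', 'segmentor') and checkpoint_prefix('logs_dir', 'discriminator') as the paths.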