How to save and load with the ModelCheckpoint callback in TensorFlow 2.4.0?

Problem description

I am using the Keras API in TensorFlow 2.4.0 and save my model at different steps with this line of code:

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    os.path.join(log_dir, "checkpoint_{epoch:02d}.tf"),
    save_freq=train_steps_per_epoch * cfg.log.checkpoint_save_every_epochs,
    save_weights_only=False)
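An integer save_freq is counted in batches, not epochs, and the epoch placeholder in the file name is filled 1-based. The pure-Python sketch below (it does not call TensorFlow) simulates that counting so you can predict which checkpoint names should exist on disk before you load one. The function name expected_checkpoints and the example numbers are hypothetical, and the sketch assumes the callback sees every batch (steps_per_execution=1).

```python
def expected_checkpoints(train_steps_per_epoch, save_every_epochs, num_epochs):
    """Hypothetical helper: simulate which files the callback above would write."""
    save_freq = train_steps_per_epoch * save_every_epochs  # counted in batches
    template = "checkpoint_{epoch:02d}.tf"
    names = []
    batches_since_save = 0
    for epoch in range(num_epochs):  # 0-based epoch index, as seen by callbacks
        for _ in range(train_steps_per_epoch):
            batches_since_save += 1
            if batches_since_save >= save_freq:
                batches_since_save = 0
                # The file name uses the 1-based epoch number.
                names.append(template.format(epoch=epoch + 1))
    return names


print(expected_checkpoints(100, 5, 12))  # ['checkpoint_05.tf', 'checkpoint_10.tf']
```

With these numbers, cfg.training.checkpoint_epoch would have to be 5 or 10 for the load path to point at a checkpoint that was actually written.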

After saving the model, I get some warnings:

[2021-03-17 12:11:09,974][absl][WARNING] - Found untraced functions such as nl_0_layer_call_fn, nl_0_layer_call_and_return_conditional_losses, nl_1_layer_call_fn, nl_1_layer_call_and_return_conditional_losses, conv2d_layer_call_fn while saving (showing 5 of 280). These functions will not be directly callable after loading.

When I load the model with this line of code:

        model = tf.keras.models.load_model(os.path.join(train_dir, f'checkpoint_{cfg.training.checkpoint_epoch:02d}.tf'))

I get this error:

Traceback (most recent call last):
  File "run/linear_classifier_evaluation.py", line 51, in run
    model = tf.keras.models.load_model(os.path.join(train_dir, f'checkpoint_{cfg.training.checkpoint_epoch:02d}.tf'))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py", line 212, in load_model
    return saved_model_load.load(filepath, compile, options)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/load.py", line 138, in load
    keras_loader.load_layers(compile=compile)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/load.py", line 376, in load_layers
    node_metadata.metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/load.py", line 417, in _load_layer
    obj, setter = self._revive_from_config(identifier, metadata, node_id)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/load.py", line 434, in _revive_from_config
    self._revive_graph_network(metadata, node_id) or
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/load.py", line 471, in _revive_graph_network
    inputs=[], outputs=[], name=config['name'])
KeyError: 'name'
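The traceback shows load_model taking the SavedModel branch (saved_model/load.py), so each checkpoint_XX.tf path is a SavedModel directory rather than a single file; a SavedModel directory contains a saved_model.pb file. As a sanity check before calling load_model, a sketch like the following confirms that the directory for the requested epoch exists. Both helper names are hypothetical, and the demo fakes a checkpoint with an empty saved_model.pb.

```python
import os
import tempfile


def checkpoint_path(train_dir, epoch):
    """Hypothetical helper: the path the load line above builds for `epoch`."""
    return os.path.join(train_dir, f"checkpoint_{epoch:02d}.tf")


def is_saved_model_dir(path):
    """True if `path` is a directory that contains a saved_model.pb file."""
    return os.path.isdir(path) and os.path.isfile(os.path.join(path, "saved_model.pb"))


# Demo: fake a checkpoint for epoch 5 only, then check epochs 5 and 3.
with tempfile.TemporaryDirectory() as train_dir:
    os.makedirs(checkpoint_path(train_dir, 5))
    open(os.path.join(checkpoint_path(train_dir, 5), "saved_model.pb"), "wb").close()
    found_5 = is_saved_model_dir(checkpoint_path(train_dir, 5))
    found_3 = is_saved_model_dir(checkpoint_path(train_dir, 3))

print(found_5, found_3)  # True False
```

A missing directory would produce a different error than the KeyError above, so this only rules out loading the wrong epoch.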

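I searched the web for this error and found suggestions to implement the get_config and from_config methods in my model class. However, I am not using the H5 format, so if I understood the Keras tutorials and other sources correctly, I should not need to do that; those suggestions seem to be solutions for TF 1.x.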
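The last frame of the traceback reads config['name'] from the model's saved config. If a model class overrides get_config (the question does not show the model class, so whether this applies is an assumption), the usual Keras pattern is to start from the parent's config so that base keys such as 'name' survive, then add the subclass's own arguments. The sketch below uses hypothetical stand-in classes in plain Python, not the Keras API, only to show that pattern and the lookup that fails without it.

```python
class BaseModel:
    """Hypothetical stand-in for a Keras base class; NOT the Keras API."""

    def __init__(self, name="model"):
        self.name = name

    def get_config(self):
        return {"name": self.name}


class BrokenModel(BaseModel):
    def __init__(self, units, name="broken"):
        super().__init__(name=name)
        self.units = units

    def get_config(self):
        # Replaces the parent's config entirely: the dict has no "name" key.
        return {"units": self.units}


class FixedModel(BaseModel):
    def __init__(self, units, name="fixed"):
        super().__init__(name=name)
        self.units = units

    def get_config(self):
        # Start from the parent's config, then add this class's arguments.
        config = super().get_config()
        config.update({"units": self.units})
        return config


def revive_name(config):
    # Mirrors only the failing lookup from the traceback.
    return config["name"]


print(revive_name(FixedModel(8).get_config()))  # fixed
```

Calling revive_name(BrokenModel(8).get_config()) raises KeyError: 'name', the same exception type and key as in the traceback.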

I would welcome any help or advice on where to look.

Tags: tensorflow, keras
