首页 > 解决方案 > 使用 TensorFlow 保存检查点

问题描述

我的 CNN 模型有 3 个文件夹,它们是train_data, val_data, test_data.

当我训练我的模型时,我发现准确度可能会有所不同,有时最后一个时期并没有显示出最好的准确度。例如,最后一个 epoch 的准确度是 71%,但我发现在较早的 epoch 中准确度更高。我想保存具有更高准确性的那个时代的检查点,然后使用该检查点来预测我的模型test_data

我训练我的模型train_data并预测val_data并保存模型的检查点,如下所示:

    print("{} Saving checkpoint of model...". format(datetime.now()))
    checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch' + str(epoch) + '.ckpt')
    save_path = saver.save(session, checkpoint_path)

在开始之前tf.Session()我有这一行:

saver = tf.train.Saver()

我想知道如何保存具有更高准确性的最佳时期,然后将此检查点用于我的test_data

标签: pythonpython-3.xtensorflowtensorflow-estimator

解决方案


tf.train.Saver()文档描述了以下内容:

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

请注意,如果您传递global_step给保护程序,您将生成包含全局步骤编号的检查点文件。我通常每 X 分钟保存一次检查点,然后回来查看结果并在适当的步长值处选择一个检查点。如果您使用 tensorboard,您会发现这很直观,因为您的所有图表也可以按全局步骤显示。

https://www.tensorflow.org/api_docs/python/tf/train/Saver


推荐阅读