python - 恢复 TensorFlow 模型的方式差异
问题描述
我已经看到并尝试了两种方法,但无法理解它有什么区别。以下是我使用的两种方法:
方法 1:
saver = tf.train.import_meta_graph(tf.train.latest_checkpoint(model_path)+".meta")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
if(tf.train.checkpoint_exists(tf.train.latest_checkpoint(model_path))):
saver.restore(sess, tf.train.latest_checkpoint(model_path))
print(tf.train.latest_checkpoint(model_path) + "Session Loaded for Testing")
方法二:
saver = tf.train.Saver()
sess =tf.Session()
sess.run(tf.global_variables_initializer())
if(tf.train.checkpoint_exists(tf.train.latest_checkpoint(model_path))):
saver.restore(sess, tf.train.latest_checkpoint(model_path))
print(tf.train.latest_checkpoint(model_path) + "Session Loaded for Testing")
我想知道的是:
以上两种方法有什么区别?
加载模型的最佳方法是什么?
请让我知道您对此有何建议。
解决方案
我会尽量简明扼要,所以这是我对此事的 2 美分。我将评论您代码的重要行以指出我的想法。
# Importing the meta graph is same as building the same graph from scratch
# creating the same variables, creating the same placeholders and ect.
# Basically you are only importing the graph definition
saver = tf.train.import_meta_graph(tf.train.latest_checkpoint(model_path)+".meta")
sess = tf.Session()
# Absolutely no need to initialize the variables here. They will be initialized
# when the you restore the learned variables.
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
if(tf.train.checkpoint_exists(tf.train.latest_checkpoint(model_path))):
saver.restore(sess, tf.train.latest_checkpoint(model_path))
print(tf.train.latest_checkpoint(model_path) + "Session Loaded for Testing")
至于第二种方法:
# You can't create a saver object like this, you will get an error "No variables to save", which is true.
# You haven't created any variables. The workaround for doing this is:
# saver = tf.train.Saver(defer_build=True) and then after building the graph
# ....Graph building code goes here....
# saver.build()
saver = tf.train.Saver()
sess = tf.Session()
# Absolutely no need to initialize the variables here. They will be initialized
# when the you restore the learned variables.
sess.run(tf.global_variables_initializer())
if(tf.train.checkpoint_exists(tf.train.latest_checkpoint(model_path))):
saver.restore(sess, tf.train.latest_checkpoint(model_path))
print(tf.train.latest_checkpoint(model_path) + "Session Loaded for Testing")
所以第一种方法没有错,但第二种方法完全不正确。不要误会我的意思,但我不喜欢他们中的任何一个。不过,这只是个人口味。另一方面,我想做的是以下内容:
# Have a class that creates the model and instantiate an object of that class
my_trained_model = MyModel()
# This is basically the same as what you are doing with
# saver = tf.train.import_meta_graph(tf.train.latest_checkpoint(model_path)+".meta")
# Then, once I have the graph build, I will create a saver object
saver = tf.train.Saver()
# Then I will create a session
with tf.Session() as sess:
# Restore the trained variables here
saver.restore(sess, model_checkpoint_path)
# Now I can do whatever I want with the my_trained_model object
我希望这对你有帮助。
推荐阅读
- mysql - 尝试上传图像时 Multer 出错
- java - Android在某些部分显示文本语言错误
- python - 图像下载但字节大小为 0 同时响应是 respons.ok 和 python 中的 200 代码?
- amazon-ec2 - 由于 libssl.so.10 无法运行 yum:无法打开共享对象文件:没有这样的文件或目录
- python - 美丽的汤'find_all'功能似乎没有刮掉“find_all('div',class_ ='ais-infinite-hits ais-results-as-block')”
- python - 如何匹配重复项以及如果匹配如何删除python列表中的第二个?
- c - 这个表达式怎么读?
- javascript - 是否可以使用 java 或 javascript 检查移动应用程序是否安装?
- java - java - volatile 对象的字段不可见?
- wso2 - WSO2 ESB 在 payloadfactory 中添加斜线