python - 在 tensorflow 中恢复保存的模型时出现问题,如何调试?
问题描述
在 tensorflow 中训练模型后,保存如下:
saver = tf.train.Saver()
saver.save(sess,'myModel/Path/Model_1')
生成文件称为:
- Model_1.meta
- Model_1.index
- Model_1.data-000000-of-000001
- 检查点
现在要在创建新会话后恢复模型,并以与最初创建完全相同的方式初始化 tensorflow 图,我将其恢复如下:
sess = tf.Session()
# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()
sess.run(init)
imported_meta = tf.train.Saver()
imported_meta.restore(sess,'myModel/Path/Model_1.meta')
这会引发以下错误:
InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [6152,32] rhs shape= [6164,80]
[[Node: save_2/Assign_3 = Assign[T=DT_FLOAT, _class=["loc:@DGNS/bidirectional_rnn/bw/basic_lstm_cell/kernel"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](DGNS/bidirectional_rnn/bw/basic_lstm_cell/kernel, save_2/RestoreV2/_111)]]
Caused by op u'save_2/Assign_3', defined at:
File "/usr/lib/python2.7/dist-packages/spyderlib/widgets /externalshell/start_ipython_kernel.py", line 205, in <module>
__ipythonkernel__.start()
"/usr/lib/python2.7/dist-packages/IPython/kernel/zmq/kernelapp.py", line 459, in start
ioloop.IOLoop.instance().start()
File "/usr/lib/python2.7/dist-packages/zmq/eventloop/ioloop.py", line 162, in start
super(ZMQIOLoop, self).start()
File "/usr/lib/python2.7/dist-packages/zmq/eventloop/minitornado/ioloop.py", line 830, in start
self._run_callback(callback)
File "/usr/lib/python2.7/dist-packages/zmq/eventloop/minitornado/ioloop.py", line 603, in _run_callback
ret = callback()
... ... ETC
我需要帮助了解这里发生的事情。该错误提示某些形状不匹配问题。但我不明白这是怎么回事,因为我使用完全相同的代码来生成模型和初始化新图形。代码中唯一的区别是模型加载部分。
如何开始调试此错误以获取有关如何正确加载模型的提示?
解决方案
我很确定您不应该加载 .meta 文件。很难理解,因为它为检查点输出 3 个不同的文件。试试这个:
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph(
'myModel/Path/Model_1.meta', clear_devices=True)
new_saver.restore(sess, 'myModel/Path/Model_1')
另外,为了澄清起见,您是将完整模型也存储在 .pb 文件中,还是仅生成这些检查点?
推荐阅读
- logrotate - logrotate 删除第二天的旧文件,这不应该发生
- cocoa - 打印机选项只提供纵向,没有切换选项
- amazon-web-services - AWS s3 文件覆盖不适用于 crontab
- php - count():参数必须是数组或者是laravel中实现Countable的对象
- python - 是否有另一种方法可以顺序单击列表中的所有 Web 元素?
- angular - Angular Progressive Web App(PWA)离线不起作用
- python - 当您不知道总页数时,有没有办法使用多处理进行 api 调用
- jestjs - 在 puppeteer 中模拟不同的 window.location
- reactjs - 动态导入图片 react
- python-3.x - 我可以在 Windows 操作系统中使用带有纯 python3 的 RDkit