首页 > 解决方案 > 了解 TensorFlow 检查点加载?

问题描述

TF 检查点中包含什么?例如,估计器存储一个包含GraphDef原型的单独文件,您基本上可以执行 a tf.import_graph_def(),然后创建 atf.train.Saver()并将检查点恢复到图中。现在,如果您有另一个GraphDef描述完全不同的图表,恰好共享完全相同的变量名称以及匹配的变量维度,您能否将检查点加载到该图表中?换句话说,它只是一个变量名到值映射,还是假设在加载过程中要检查的图有其他东西?如果您尝试将检查点加载到作为原始图的子集的图中(即张量维度和名称匹配,但缺少一些名称)怎么办?

标签: pythontensorflow

解决方案


人们什么时候开始阅读文档(?): https ://www.tensorflow.org/mobile/prepare_models

这些是不同的概念。只要形状匹配,您就可以只加载权重。如果出现不匹配,您将得到:

从检查点恢复失败。这很可能是由于当前图与检查点的图不匹配造成的。请确保您没有根据检查点更改预期的图表。

但是,您可以调整一个不平凡的案例,其中图表完全不同:

import tensorflow as tf
import numpy as np

test_data = np.arange(4).reshape(1, 2, 2, 1)

# a simple graph and everything is fine
input = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
output = tf.layers.conv2d(input, 3, kernel_size=1, name='test', use_bias=False)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(output, {input: test_data}))
  saver = tf.train.Saver()
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print(tf.trainable_variables())

# reset previous elements
tf.reset_default_graph()

# a new graph
input = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
# and wait: this is complete different but same name and shape
W = tf.get_variable('test/kernel', shape=[1, 1, 1, 3])
# but the graph has different operations
output = input + W

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver()
  saver.restore(sess, "/tmp/model.ckpt")
  print(sess.run(output, {input: test_data}))

就我而言,我得到了:

# 1st version (original graph)
[[[[-0.         -0.         -0.        ]
   [-0.08429337 -1.0156475  -0.42691123]]

  [[-0.16858673 -2.031295   -0.85382247]
   [-0.2528801  -3.0469427  -1.2807337 ]]]]
# 2nd version (altered graph)
[[[[-0.08429337 -1.0156475  -0.42691123]
   [ 0.91570663 -0.01564753  0.57308877]]

  [[ 1.9157066   0.98435247  1.5730888 ]
   [ 2.9157066   1.9843525   2.5730886 ]]]]

推荐阅读