Tensorflow: is there a way to load a pretrained model without having to redefine all the variables?

Problem description

I'm trying to split my code into different modules, one where the model is trained, another which analyzes the weights in the model.

When I save the model using

save_path = saver.save(sess, "checkpoints5/text8.ckpt")

It makes 4 files, ['checkpoint', 'text8.ckpt.data-00000-of-00001', 'text8.ckpt.meta', 'text8.ckpt.index']

I tried restoring this in the separate module using this code

train_graph = tf.Graph()
with train_graph.as_default():
    saver = tf.train.Saver()


with tf.Session(graph=train_graph) as sess:
    saver.restore(sess, tf.train.latest_checkpoint('MODEL4'))
    embed_mat = sess.run(embedding)

But I get this error message

ValueError                                Traceback (most recent call last)
<ipython-input-15-deaad9b67888> in <module>()
      1 train_graph = tf.Graph()
      2 with train_graph.as_default():
----> 3     saver = tf.train.Saver()
      4 
      5 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in __init__(self, var_list, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, saver_def, builder, defer_build, allow_empty, write_version, pad_step_number, save_relative_paths, filename)
   1309           time.time() + self._keep_checkpoint_every_n_hours * 3600)
   1310     elif not defer_build:
-> 1311       self.build()
   1312     if self.saver_def:
   1313       self._check_saver_def()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in build(self)
   1318     if context.executing_eagerly():
   1319       raise RuntimeError("Use save/restore instead of build in eager mode.")
-> 1320     self._build(self._filename, build_save=True, build_restore=True)
   1321 
   1322   def _build_eager(self, checkpoint_path, build_save, build_restore):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in _build(self, checkpoint_path, build_save, build_restore)
   1343           return
   1344         else:
-> 1345           raise ValueError("No variables to save")
   1346       self._is_empty = False
   1347 

ValueError: No variables to save

After reading up on this issue, it seems that I need to redefine all the variables used when I trained the model.

Is there a way to access the weights without having to redefine everything? The weights are just numbers, surely there must be a way to access them directly?

Tags: tensorflow

Solution


If you only need to access the variables stored in a checkpoint, check out the checkpoint_utils library. It provides three useful API functions: load_checkpoint, list_variables and load_variable. I'm not sure whether there is a better way, but you can certainly use these functions to extract a dictionary of all the variables in a checkpoint, like this:

import tensorflow as tf

ckpt = 'checkpoints5/text8.ckpt'
var_dict = {name: tf.train.load_checkpoint(ckpt).get_tensor(name)
            for name, _ in tf.train.list_variables(ckpt)}
print(var_dict)

To load a pretrained model without redefining all the variables, you need more than just a checkpoint. A checkpoint contains only the variable values; without the actual graph (and an appropriate mapping), it doesn't know how to restore them, i.e. how to map them back onto a graph. A SavedModel works better for this case: it saves the model's MetaGraph together with all the variables, so when you restore it you don't have to redefine everything manually. The following code is a minimal example using only simple_save.

To save the trained model:

import tensorflow as tf

x = tf.placeholder(tf.float32)
y_ = tf.reshape(x, [-1, 1])
y_ = tf.layers.dense(y_, units=1)
loss = tf.losses.mean_squared_error(labels=x, predictions=y_)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for _ in range(10):
        sess.run(train_op, feed_dict={x: range(10)})
    # Let's check the bias here so that we can make sure
    # the model we restored later on is indeed our trained model here.
    d_b = sess.graph.get_tensor_by_name('dense/bias:0')
    print(sess.run(d_b))
    tf.saved_model.simple_save(sess, 'test', inputs={"x": x}, outputs={"y": y_})

To restore the saved model:

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # A model saved by simple_save will be treated as a graph for inference / serving,
    # i.e. uses the tag tag_constants.SERVING
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], 'test')
    d_b = sess.graph.get_tensor_by_name('dense/bias:0')
    print(sess.run(d_b))
