首页 > 解决方案 > 无论我指定哪个变量,Tensorflow Saver 都会恢复所有变量

问题描述

我正在尝试从 Tensorflow 图中保存和恢复变量的子集,以便丢弃我不需要的所有内容,并且它们的权重不会占用内存。将所需变量的列表或字典传递给的常见建议tf.train.Saver不起作用:保护程序无论如何都会恢复所有变量。

一个最小的工作示例:

import os
import tensorflow as tf
sess = tf.Session()
with sess.as_default():
    v1 =  tf.get_variable("v1", [5, 5, 3])
    v2 =  tf.get_variable("v2", [5, 5, 3])
    saver = tf.train.Saver([v2])
    initializer2 = tf.variables_initializer([v1, v2])
    sess.run(initializer2)
saver.save(sess, '/path/to/tf_model')

sess2 = tf.Session()
checkpoint = '/path/to/tf_model.meta'
saver.restore(sess2, tf.train.latest_checkpoint(os.path.dirname(checkpoint)))

with sess2.as_default(), sess2.graph.as_default():
    loaded_vars = tf.trainable_variables()

print(loaded_vars)

输出

[<tf.Variable 'v1:0' shape=(5, 5, 3) dtype=float32_ref>,
 <tf.Variable 'v2:0' shape=(5, 5, 3) dtype=float32_ref>]

尽管如此,print(saver._var_list)输出

[<tf.Variable 'v2:0' shape=(5, 5, 3) dtype=float32_ref>]

这里有什么问题?

标签: pythontensorflow

解决方案


这就是你想要做的。请仔细检查代码。

保存选定的变量

import tensorflow as tf

tf.reset_default_graph()

# =============================================================================
# to save
# =============================================================================

# create variables
v1 =  tf.get_variable(name="v1", initializer=[5, 5, 3])
v2 =  tf.get_variable(name="v2", initializer=[5, 5, 3])

# initialize variables
init_op = tf.global_variables_initializer()

# ops to save variable v2
saver = tf.train.Saver({"my_v2": v2})

with tf.Session() as sess:
    sess.run(init_op)
    save_path = saver.save(sess, './tf_vars/model.ckpt')
    print("Model saved in file: %s" % save_path)

'Output':
Model saved in file: ./tf_vars/model.ckpt

恢复保存的变量

# =============================================================================
# to restore
# =============================================================================

# Create some variables.
v1 = tf.Variable(initial_value=[0, 0, 0], name="v1")
v2 = tf.Variable(initial_value=[0, 0, 0], name="v2")

# initialize variables
init_op = tf.global_variables_initializer()

#  ops to restore variable v2.
saver = tf.train.Saver({"my_v2": v2})

with tf.Session() as sess:
    sess.run(init_op)
    # Restore variables from disk.
    saver.restore(sess, './tf_vars/model.ckpt')
    print("v1: %s" % v1.eval())
    print("v2: %s" % v2.eval())

    print("V2 variable restored.")

'Output':
v1: [0 0 0]
v2: [5 5 3]
V2 variable restored.

推荐阅读