python - 无论我指定哪个变量,Tensorflow Saver 都会恢复所有变量
问题描述
我正在尝试从 Tensorflow 图中保存和恢复变量的子集,以便丢弃我不需要的所有内容,并且它们的权重不会占用内存。将所需变量的列表或字典传递给的常见建议tf.train.Saver
不起作用:保护程序无论如何都会恢复所有变量。
一个最小的工作示例:
import os
import tensorflow as tf
sess = tf.Session()
with sess.as_default():
v1 = tf.get_variable("v1", [5, 5, 3])
v2 = tf.get_variable("v2", [5, 5, 3])
saver = tf.train.Saver([v2])
initializer2 = tf.variables_initializer([v1, v2])
sess.run(initializer2)
saver.save(sess, '/path/to/tf_model')
sess2 = tf.Session()
checkpoint = '/path/to/tf_model.meta'
saver.restore(sess2, tf.train.latest_checkpoint(os.path.dirname(checkpoint)))
with sess2.as_default(), sess2.graph.as_default():
loaded_vars = tf.trainable_variables()
print(loaded_vars)
输出
[<tf.Variable 'v1:0' shape=(5, 5, 3) dtype=float32_ref>,
<tf.Variable 'v2:0' shape=(5, 5, 3) dtype=float32_ref>]
尽管如此,print(saver._var_list)
输出
[<tf.Variable 'v2:0' shape=(5, 5, 3) dtype=float32_ref>]
这里有什么问题?
解决方案
这就是你想要做的。请仔细检查代码。
保存选定的变量
import tensorflow as tf
tf.reset_default_graph()
# =============================================================================
# to save
# =============================================================================
# create variables
v1 = tf.get_variable(name="v1", initializer=[5, 5, 3])
v2 = tf.get_variable(name="v2", initializer=[5, 5, 3])
# initialize variables
init_op = tf.global_variables_initializer()
# ops to save variable v2
saver = tf.train.Saver({"my_v2": v2})
with tf.Session() as sess:
sess.run(init_op)
save_path = saver.save(sess, './tf_vars/model.ckpt')
print("Model saved in file: %s" % save_path)
'Output':
Model saved in file: ./tf_vars/model.ckpt
恢复保存的变量
# =============================================================================
# to restore
# =============================================================================
# Create some variables.
v1 = tf.Variable(initial_value=[0, 0, 0], name="v1")
v2 = tf.Variable(initial_value=[0, 0, 0], name="v2")
# initialize variables
init_op = tf.global_variables_initializer()
# ops to restore variable v2.
saver = tf.train.Saver({"my_v2": v2})
with tf.Session() as sess:
sess.run(init_op)
# Restore variables from disk.
saver.restore(sess, './tf_vars/model.ckpt')
print("v1: %s" % v1.eval())
print("v2: %s" % v2.eval())
print("V2 variable restored.")
'Output':
v1: [0 0 0]
v2: [5 5 3]
V2 variable restored.
推荐阅读
- python - 无法从列表中删除值
- azure - 为 Azure Blob 生成预签名 URL 引发空指针异常
- php - 如何在事件侦听器中显示警报消息
- html - 如何对齐 checkbox.label 并水平输入
- python - Pandas 从 dict 将元组转换为索引创建数据框
- javascript - 为什么我更改了新数组中的值但我的默认数组也更改了
- python - Hangman 输出删除了错误的元素
- python - 从 dict 创建数据框,其中键是元组,值是列表
- sql - 如何使用@Query 注解实现全文搜索?
- kibana - 如何使用复选框更改 vega-lite 可视化中文本的背景和颜色?