TensorFlow: is there a way to load a pretrained model without having to redefine all the variables?
Problem description
I'm trying to split my code into separate modules: one where the model is trained and another that analyzes the weights in the model.
When I save the model using

```python
save_path = saver.save(sess, "checkpoints5/text8.ckpt")
```

it creates four files: `['checkpoint', 'text8.ckpt.data-00000-of-00001', 'text8.ckpt.meta', 'text8.ckpt.index']`.
I tried restoring them in the separate module with this code:
```python
train_graph = tf.Graph()
with train_graph.as_default():
    saver = tf.train.Saver()


with tf.Session(graph=train_graph) as sess:
    saver.restore(sess, tf.train.latest_checkpoint('MODEL4'))
    embed_mat = sess.run(embedding)
```
But I get this error message:

```
ValueError                                Traceback (most recent call last)
<ipython-input-15-deaad9b67888> in <module>()
      1 train_graph = tf.Graph()
      2 with train_graph.as_default():
----> 3     saver = tf.train.Saver()
      4 
      5 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in __init__(self, var_list, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, saver_def, builder, defer_build, allow_empty, write_version, pad_step_number, save_relative_paths, filename)
   1309 time.time() + self._keep_checkpoint_every_n_hours * 3600)
   1310 elif not defer_build:
-> 1311 self.build()
   1312 if self.saver_def:
   1313 self._check_saver_def()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in build(self)
   1318 if context.executing_eagerly():
   1319 raise RuntimeError("Use save/restore instead of build in eager mode.")
-> 1320 self._build(self._filename, build_save=True, build_restore=True)
   1321 
   1322 def _build_eager(self, checkpoint_path, build_save, build_restore):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in _build(self, checkpoint_path, build_save, build_restore)
   1343 return
   1344 else:
-> 1345 raise ValueError("No variables to save")
   1346 self._is_empty = False
   1347 

ValueError: No variables to save
```
After reading up on this issue, it seems that I need to redefine all the variables used when I trained the model.
Is there a way to access the weights without having to redefine everything? The weights are just numbers, surely there must be a way to access them directly?
Solution
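To access only the variables in a checkpoint, take a look at the `checkpoint_utils` library. It provides three useful API functions: `load_checkpoint`, `list_variables`, and `load_variable` (exposed as `tf.train.load_checkpoint`, `tf.train.list_variables`, and `tf.train.load_variable`). I'm not sure whether there is a better way, but you can certainly use these functions to extract a dictionary of all the variables in the checkpoint, like this: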
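```python
import tensorflow as tf

ckpt = 'checkpoints5/text8.ckpt'  # the prefix that was passed to saver.save()
reader = tf.train.load_checkpoint(ckpt)  # open the checkpoint once, not once per variable
var_dict = {name: reader.get_tensor(name)  # each value is a NumPy array
            for name, _ in tf.train.list_variables(ckpt)}
print(var_dict)
```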
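If you only want to see what a checkpoint contains before pulling out every value, `tf.train.list_variables` returns `(name, shape)` pairs and `tf.train.load_variable` fetches a single variable by name. The sketch below is self-contained: it first writes a tiny throwaway checkpoint to a temporary directory so it can run anywhere. The variable names `embedding` and `softmax_b` are hypothetical, not names from the question's model, and the `tf.compat.v1` calls are an assumption made so the sketch runs under both TF 1.x and 2.x.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # Saver-based checkpoints need graph mode under TF 2.x

# Write a small throwaway checkpoint (hypothetical variable names).
ckpt_dir = tempfile.mkdtemp()
with tf.Graph().as_default():
    tf1.Variable(np.arange(12, dtype=np.float32).reshape(4, 3), name="embedding")
    tf1.Variable(np.zeros(4, dtype=np.float32), name="softmax_b")
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        ckpt = tf1.train.Saver().save(sess, os.path.join(ckpt_dir, "demo.ckpt"))

# Inspect names and shapes without building any graph.
shapes = dict(tf.train.list_variables(ckpt))  # {name: shape as a list of ints}
for name, shape in sorted(shapes.items()):
    print(name, shape)

# Load just one variable as a NumPy array.
embed_mat = tf.train.load_variable(ckpt, "embedding")
print(embed_mat.shape)
```

Listing shapes first is cheap, which helps when a checkpoint holds large tensors and you only need one of them.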
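To load a pretrained model without having to redefine all the variables, you need more than the checkpoint data. The checkpoint's data files hold only the variable values; without the actual graph (and a proper mapping), they don't say how to restore those variables, i.e. how to map them onto a graph. `SavedModel` is better suited to this case: it saves both the model's `MetaGraph` and all of its variables, so you don't have to redefine everything manually when restoring it. The following code is an example that uses only `simple_save`.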
To save the trained model:
```python
import tensorflow as tf

x = tf.placeholder(tf.float32)
x_col = tf.reshape(x, [-1, 1])        # shape [batch, 1]
y_ = tf.layers.dense(x_col, units=1)  # creates dense/kernel and dense/bias
# Use x_col as the labels so labels and predictions share shape [batch, 1];
# passing the flat x would silently broadcast to a [batch, batch] error matrix.
loss = tf.losses.mean_squared_error(labels=x_col, predictions=y_)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for _ in range(10):
        sess.run(train_op, feed_dict={x: range(10)})
    # Print the bias so we can check later that the restored model
    # is indeed the model trained here.
    d_b = sess.graph.get_tensor_by_name('dense/bias:0')
    print(sess.run(d_b))
    # simple_save exports to 'test'; that directory must not already exist.
    tf.saved_model.simple_save(sess, 'test', inputs={"x": x}, outputs={"y": y_})
```
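To see concretely what "the MetaGraph and all the variables" means on disk, the sketch below exports a one-weight graph with `simple_save` and lists the export directory: `saved_model.pb` holds the serialized graph and signatures, and the `variables/` subdirectory holds an ordinary checkpoint. The names `x`, `w`, `y` and the temporary directory are illustrative assumptions, and `tf.compat.v1` is used so it also runs under TF 2.x.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # simple_save is a graph-mode API

export_dir = os.path.join(tempfile.mkdtemp(), "model")  # must not exist yet

with tf.Graph().as_default():
    x = tf1.placeholder(tf.float32, shape=[None], name="x")
    w = tf1.Variable(2.0, name="w")  # hypothetical single weight
    y = tf.multiply(x, w, name="y")
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        tf1.saved_model.simple_save(sess, export_dir, inputs={"x": x}, outputs={"y": y})

# saved_model.pb: the MetaGraph; variables/: the variable values.
top_level = sorted(os.listdir(export_dir))
var_files = sorted(os.listdir(os.path.join(export_dir, "variables")))
print(top_level)
print(var_files)
```

Depending on the TensorFlow version, the export directory may contain extra files besides these two entries.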
To restore the saved model:
```python
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # A model saved by simple_save is treated as a graph for inference/serving,
    # i.e. it uses the tag tag_constants.SERVING.
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], 'test')
    d_b = sess.graph.get_tensor_by_name('dense/bias:0')
    print(sess.run(d_b))
```
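A related option, not part of the answer above: the question's `saver.save` call already wrote `text8.ckpt.meta`, which is a serialized MetaGraph. `tf.train.import_meta_graph` rebuilds the graph from that file and returns a `Saver`, so the variables can be restored and fetched by name without redefining them. The sketch below creates its own small checkpoint in a temporary directory so it runs on its own; the variable name `embedding` is hypothetical, and for a real model you would look the names up (as the `print` below does). It uses `tf.compat.v1` so it also runs under TF 2.x.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

ckpt_dir = tempfile.mkdtemp()

# "Training" module: define a variable and save it with a plain Saver.
with tf.Graph().as_default():
    tf1.Variable(np.ones((5, 2), dtype=np.float32), name="embedding")  # hypothetical name
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        tf1.train.Saver().save(sess, os.path.join(ckpt_dir, "model.ckpt"))

# "Analysis" module: nothing is redefined; the graph comes from the .meta file.
with tf.Graph().as_default() as graph:
    latest = tf1.train.latest_checkpoint(ckpt_dir)  # prefix of the newest checkpoint
    saver = tf1.train.import_meta_graph(latest + ".meta")  # rebuilds graph, returns a Saver
    with tf1.Session(graph=graph) as sess:
        saver.restore(sess, latest)
        restored = {v.op.name: v for v in tf1.global_variables()}
        print(sorted(restored))  # variable names available for lookup
        embed_mat = sess.run(restored["embedding"])
print(embed_mat.shape)
```

Fetching the restored `Variable` objects, rather than a hard-coded `"name:0"` tensor, keeps the lookup working whether the graph used reference or resource variables.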