python - 了解 TensorFlow 检查点加载?
问题描述
TF 检查点中包含什么?例如,估计器存储一个包含GraphDef
原型的单独文件,您基本上可以执行 a tf.import_graph_def()
,然后创建 atf.train.Saver()
并将检查点恢复到图中。现在,如果您有另一个GraphDef
描述完全不同的图表,恰好共享完全相同的变量名称以及匹配的变量维度,您能否将检查点加载到该图表中?换句话说,它只是一个变量名到值映射,还是假设在加载过程中要检查的图有其他东西?如果您尝试将检查点加载到作为原始图的子集的图中(即张量维度和名称匹配,但缺少一些名称)怎么办?
解决方案
人们什么时候开始阅读文档(?): https ://www.tensorflow.org/mobile/prepare_models
这些是不同的概念。只要形状匹配,您就可以只加载权重。如果出现不匹配,您将得到:
从检查点恢复失败。这很可能是由于当前图与检查点的图不匹配造成的。请确保您没有根据检查点更改预期的图表。
但是,您可以调整一个不平凡的案例,其中图表完全不同:
import tensorflow as tf
import numpy as np
test_data = np.arange(4).reshape(1, 2, 2, 1)
# a simple graph and everything is fine
input = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
output = tf.layers.conv2d(input, 3, kernel_size=1, name='test', use_bias=False)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(output, {input: test_data}))
saver = tf.train.Saver()
save_path = saver.save(sess, "/tmp/model.ckpt")
print(tf.trainable_variables())
# reset previous elements
tf.reset_default_graph()
# a new graph
input = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
# and wait: this is complete different but same name and shape
W = tf.get_variable('test/kernel', shape=[1, 1, 1, 3])
# but the graph has different operations
output = input + W
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess, "/tmp/model.ckpt")
print(sess.run(output, {input: test_data}))
就我而言,我得到了:
# 1st version (original graph)
[[[[-0. -0. -0. ]
[-0.08429337 -1.0156475 -0.42691123]]
[[-0.16858673 -2.031295 -0.85382247]
[-0.2528801 -3.0469427 -1.2807337 ]]]]
# 2nd version (altered graph)
[[[[-0.08429337 -1.0156475 -0.42691123]
[ 0.91570663 -0.01564753 0.57308877]]
[[ 1.9157066 0.98435247 1.5730888 ]
[ 2.9157066 1.9843525 2.5730886 ]]]]
推荐阅读
- python - Spark Dataframe Select 使用 Python 的列列表
- php - 在一条线上绘制点,在数组中的点之间
- c++ - 如果元素存在,则获取容器中元素的索引
- python - 如何在python中完成多处理部分后运行代码
- youtube-api - 获取有关 YouTube 上的实时流媒体的实时信息
- javascript - 有没有办法使用 chrome 开发者工具在我的 chrome 扩展中提供的屏幕截图 API?
- windows - 无法在主分区上设置恢复分区 ID
- python - 如何确保每个工作人员只使用一个 CPU?
- assembly - NASM - 创建文件或目录时如何设置权限
- php - 为什么在 Windows 10 中尝试使用 PHP CLI 删除(空)目录时出现“rmdir ... Permission denied:”错误?