python - 即使从 ckpt 文件恢复后,Tensorflow 权重也没有恢复
问题描述
我正在使用以下代码,恢复的权重是随机初始化,而不是存储在 ckpt 文件中的实际权重。请帮助我了解我哪里出错了。
best_val_model = 'val_E1_A86.ckpt'
model_dir = './models/'
with tf.Session(config = config) as sess:
sess.run(tf.global_variables_initializer())
print('Testing the model on 10000 Images!')
ckpt_file = os.path.join(model_dir, best_val_model)
saver = tf.train.import_meta_graph(ckpt_file)
saver.restore(sess, ckpt_file)
weights = {}
for v in tf.trainable_variables():
weights[v.name] = v.eval()
tf.train.Saver(sess,filename)
在训练期间使用保存的实际模型权重。在恢复时,正在恢复随机权重。
解决方案
我只需要删除两行sess.run(tf.golbal_variables_initializer())
,saver = tf.train.import_meta_graph(ckpt_file)
它工作正常。使用的最终代码:
saver = tf.train.Saver(tf.trainable_variables())
with tf.Session(config = config) as sess:
sess.run(tf.local_variables_initializer())
print('Testing the model on 10000 Images!')
ckpt_file = os.path.join(model_dir, best_val_model)
saver.restore(sess, ckpt_file)
推荐阅读
- python - 在 neuMF 模型中使用隐式数据集
- angular - catchError() 函数中 Angular 10 中的 HttpInterceptor 管道错误
- html - 使用 HTML 和 CSS 在所有浏览器中进行响应式网页设计
- amazon-web-services - 使用 NGINX 入口在 EKS 上公开服务和负载均衡器问题
- python - Django+Apache | ImportError:没有名为 django 的模块
- javascript - 如何避免在 reactjs 的状态数组中推送重复的对象..?
- javascript - 对这个 javascript 语法有点困惑 --> arr.indexOf(searchElement[, fromIndex])
- windows - Windows - VirtualHost 中的 Apache DocumentRoot 被忽略
- android - 执行长按时的圆形按钮背景和提示
- c++ - 字符串未在 a for 循环之外的 C++ 中打印