首页 > 解决方案 > tensorflow 仅从检查点恢复一些变量

问题描述

在检查了一个检查点(我们称之为模型 1)之后,我获得了下面的变量名列表(为简单起见缩短了):

var_list = ["ex1_model/fc2/b",
"ex1_model/fc2/b/Adam",
"ex1_model/fc2/b/Adam_1",
"ex1_model/fc2/w",
"ex1_model/fc2/w/Adam"]

假设我有一个更大的模型 2,并想用模型 1 的值初始化它的一部分。

从这里描述的名称中获取变量(因为我没有找到一种简单的方法):

def get_vars_by_name(names):
    return [v for v in tf.global_variables() if v.name in names]

构建模型 2 和保护程序以恢复:

logits = build_model(inputs)
saver = tf.train.Saver(var_list=get_vars_by_name(var_list))

saver.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))

我收到错误:

"ex1_model/fc2/w/Adam" [...] raise ValueError("No variables to save")

请帮我找出我犯的错误。我也很感激一种更简单的方法,因为这很糟糕。谢谢你。

标签: pythontensorflowdeep-learning

解决方案


解决这个问题的一个简单方法是猜测是否应该恢复变量。

def ignore_name(name):
    if name.endswith('/Adam') or name.endswith('/Adam_1'):
        return True
    return False

您应该可以通过以下方式直接使用这个想法

def get_vars_by_name(names):
    return [v for v in tf.global_variables() if v.name in names and not ignore_name(v.name)]

这甚至允许使用 ADAM 训练模型,然后切换到 SDG,反之亦然。


推荐阅读