python - tensorflow 仅从检查点恢复一些变量
问题描述
在检查了一个检查点(我们称之为模型 1)之后,我获得了下面的变量名列表(为简单起见缩短了):
var_list = ["ex1_model/fc2/b",
"ex1_model/fc2/b/Adam",
"ex1_model/fc2/b/Adam_1",
"ex1_model/fc2/w",
"ex1_model/fc2/w/Adam"]
假设我有一个更大的模型 2,并想用模型 1 的值初始化它的一部分。
从这里描述的名称中获取变量(因为我没有找到一种简单的方法):
def get_vars_by_name(names):
return [v for v in tf.global_variables() if v.name in names]
构建模型 2 和保护程序以恢复:
logits = build_model(inputs)
saver = tf.train.Saver(var_list=get_vars_by_name(var_list))
在
saver.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
我收到错误:
"ex1_model/fc2/w/Adam" [...] raise ValueError("No variables to save")
请帮我找出我犯的错误。我也很感激一种更简单的方法,因为这很糟糕。谢谢你。
解决方案
解决这个问题的一个简单方法是猜测是否应该恢复变量。
def ignore_name(name):
if name.endswith('/Adam') or name.endswith('/Adam_1'):
return True
return False
您应该可以通过以下方式直接使用这个想法
def get_vars_by_name(names):
return [v for v in tf.global_variables() if v.name in names and not ignore_name(v.name)]
这甚至允许使用 ADAM 训练模型,然后切换到 SDG,反之亦然。
推荐阅读
- r - 如何绘制几何密度和直方图的函数?
- javascript - 如何在 React 的画布上绘制组件?
- java - 如何在连接到 Wi-Fi 的 Android 设备上获取 LAN 上的 IP 地址列表
- java - 如何在休眠中将父类向下转换为继承的子类?
- kubernetes - Azure Pipeline 中的 Kubernetes 回滚?
- flutter - BlocBuilder 中的 Flutter Slider 循环
- datadog - DataDog 事件是自动恢复的
- linux - Flutter 发布到 linux
- python - tape.gradient 返回的梯度在自定义训练循环中为无
- machine-learning - 如何将模型 (GLM) 从 h2o 移植到 scikit-learn?