tensorflow - Tensorflow:将值输入现有占位符或将其删除
问题描述
我正在实现自定义 keras 回调,并且正在同一模型上执行两个连续的训练阶段。
在回调中,我创建了几个占位符来提供一些度量值,以便在训练结束时进行评估。对于第一个训练阶段这很好,因为占位符不存在,但是在第二个训练阶段这将导致错误,因为 tensorflow 将创建第二组占位符,但具有索引名称。
因此,我正在寻找一种解决方案,从第一个培训阶段将值输入占位符(可能是按名称查找占位符,然后将值输入其中)或按名称删除某些占位符以便我可以创建新的
编辑:
澄清我目前的情况。我实现了这个自定义的 Keras 回调(我将不考虑指标的计算):
class Metric(keras.callbacks.Callback):
def __init__(self):
self.val_prec_ph = tf.placeholder(shape=(), dtype=tf.float64, name="prec")
tf.summary.scalar("val_precision", self.val_prec_ph)
self.merged = tf.summary.merge_all()
self.writer = tf.summary.FileWriter(self.log_dir)
def on_train_begin(self, logs={}):
self.precision = []
def on_train_end(self, logs={}):
//do some calculations
self.precision.append(calculation)
summary = self.session.run(self.merged,
feed_dict={self.val_prec_ph: self.precision[-1]})
self.writer.add_summary(summary)
self.writer.flush()
这基本上是我做占位符的框架。由于连续运行,tensorflow 将执行以下操作:第一次训练将毫无问题地运行并将占位符命名为“prec”。然而,在第二次运行中,tensorflow 会将 self.val_prec_ph 占位符命名为“prec_”之类的名称,这将导致“prec”占位符尚未输入的错误,尽管它仍然存在。
因此,我要么想直接写入“prec”占位符,要么在第一次运行后将其删除,以免重复。
我之所以在最终培训过程结束时这样做是有原因的,yadda,yadda ...是另一个问题,这是一个不同的故事。
解决方案
这是针对您的特定问题的可能解决方案,按名称(使用)在图表中搜索占位符,tf.Graph().get_tensor_by_name()
如果找不到则创建它:
class Metric(keras.callbacks.Callback):
def __init__(self, ph_name="prec"):
try:
self.val_prec_ph = tf.get_default_graph().get_tensor_by_name(
ph_name + ':0')
# Check this solution by @rvinas to cover possible suffix/scope errors:
# https://stackoverflow.com/a/38935343/624547
except KeyError:
self.val_prec_ph = tf.placeholder(shape=(), dtype=tf.float64,
name=ph_name)
tf.summary.scalar("val_precision", self.val_prec_ph)
self.merged = tf.summary.merge_all()
self.writer = tf.summary.FileWriter(self.log_dir)
# ...
推荐阅读
- python - 如何在python中以给定频率迭代时间戳
- python-3.x - 有人能告诉我如何纠正我的 python 代码中的语法错误和格式吗?我的输出一直说“SyntaxError:无效语法”
- r - 在 R 中,如何根据特定的行/列标准有选择地将一个单元格“复制并粘贴”到另一个单元格中?
- css - React Native 外部样式表不更新组件
- api - MarketStack API 仅返回 IEXG 交换数据
- kaggle - 从 kaggle 下载后文本文件减少(8GB 到 1GB)
- sql - SQL Server - 更改列长度会导致附带损害?
- python - 如何为每一行创建 HTML 表格按钮并将第一列的值输入到 Flask 函数进行查询
- function - 如何将表达式作为文字变量传递给 Julia 中的函数
- php - php 不写入 php 文件标志和 $variable