首页 > 解决方案 > Tensorflow:将值输入现有占位符或将其删除

问题描述

我正在实现自定义 keras 回调,并且正在同一模型上执行两个连续的训练阶段。

在回调中,我创建了几个占位符来提供一些度量值,以便在训练结束时进行评估。对于第一个训练阶段这很好,因为占位符不存在,但是在第二个训练阶段这将导致错误,因为 tensorflow 将创建第二组占位符,但具有索引名称。

因此,我正在寻找一种解决方案,从第一个培训阶段将值输入占位符(可能是按名称查找占位符,然后将值输入其中)或按名称删除某些占位符以便我可以创建新的

编辑:

澄清我目前的情况。我实现了这个自定义的 Keras 回调(我将不考虑指标的计算):

class Metric(keras.callbacks.Callback):


def __init__(self):
    self.val_prec_ph = tf.placeholder(shape=(), dtype=tf.float64, name="prec")
    tf.summary.scalar("val_precision", self.val_prec_ph)

    self.merged = tf.summary.merge_all()
    self.writer = tf.summary.FileWriter(self.log_dir)

def on_train_begin(self, logs={}):

    self.precision = []

def on_train_end(self, logs={}):

   //do some calculations

    self.precision.append(calculation)

    summary = self.session.run(self.merged,
                               feed_dict={self.val_prec_ph: self.precision[-1]})

    self.writer.add_summary(summary)
    self.writer.flush()

这基本上是我做占位符的框架。由于连续运行,tensorflow 将执行以下操作:第一次训练将毫无问题地运行并将占位符命名为“prec”。然而,在第二次运行中,tensorflow 会将 self.val_prec_ph 占位符命名为“prec_”之类的名称,这将导致“prec”占位符尚未输入的错误,尽管它仍然存在。

因此,我要么想直接写入“prec”占位符,要么在第一次运行后将其删除,以免重复。

我之所以在最终培训过程结束时这样做是有原因的,yadda,yadda ...是另一个问题,这是一个不同的故事。

标签: tensorflow

解决方案


这是针对您的特定问题的可能解决方案,按名称(使用)在图表中搜索占位符,tf.Graph().get_tensor_by_name()如果找不到则创建它:

class Metric(keras.callbacks.Callback):

    def __init__(self, ph_name="prec"):
        try:
            self.val_prec_ph = tf.get_default_graph().get_tensor_by_name(
                ph_name + ':0')
            # Check this solution by @rvinas to cover possible suffix/scope errors:
            # https://stackoverflow.com/a/38935343/624547
        except KeyError:
            self.val_prec_ph = tf.placeholder(shape=(), dtype=tf.float64, 
                                              name=ph_name)

        tf.summary.scalar("val_precision", self.val_prec_ph)

        self.merged = tf.summary.merge_all()
        self.writer = tf.summary.FileWriter(self.log_dir)

    # ...

推荐阅读