首页 > 解决方案 > 我可以在会话之外或在创建会话之间有挂钩/回调吗?

问题描述

可以执行一个计划,根据train_and_evaluate()我传递的规范训练和评估模型。我可以向EvalSpecTrainSpec注册一些钩子,但有限制。

问题是我只能有一个SessionRunHook它将作为回调工作,但总是只会话中。

我的问题是我有一个更复杂的时间表。在评估期间,我还想量化模型并进一步评估该量化模型。如前所述,这里的问题是,如果我使用-like 对象,我总是在会话中。SessionRunHook

所以问题是是否有一种方法可以train_and_evaluate()在两者之间使用和注册一些回调:

train_and_evaluate(..):

  # .. deep down ..

  while <condition>:
    with tf.Session() as train_sess:
      # Do training ..

    if the_callback_i_want:
      the_callback_i_want()

    with tf.Session() as eval_sess:
      # Do evaluation ..

这可能吗?

标签: pythontensorflow

解决方案


我想你可以实现begin你自己的SessionHook子类的方法。

为了示例,我使用了iris 代码(请参阅此文档)。

import tensorflow as tf

def the_callback_i_want():
    # You need to work in a new graph so let's create a new one
    g = tf.Graph()
    with g.as_default():
        x = tf.get_variable("x", ())
        x = tf.assign_add(x, 1)
        init = tf.global_variables_initializer()
        with tf.Session() as sess:   
            sess.run(init)
            print("I'm here !", sess.run(x))


class MyHook(tf.train.SessionRunHook):

  def begin(self):
    """Called once before using the session.

    When called, the default graph is the one that will be launched in the
    session.  The hook can modify the graph by adding new operations to it.
    After the `begin()` call the graph will be finalized and the other callbacks
    can not modify the graph anymore. Second call of `begin()` on the same
    graph, should not change the graph.
    """
    the_callback_i_want()


import iris_data
# Fetch the data
(train_x, train_y), (test_x, test_y) = iris_data.load_data()

# Feature columns describe how to use the input.
my_feature_columns = []
for key in train_x.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

# Build 2 hidden layer DNN with 10, 10 units respectively.
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns, hidden_units=[10, 10],  n_classes=3)

# Fetch the data
(train_x, train_y), (test_x, test_y) = iris_data.load_data()

# Feature columns describe how to use the input.
my_feature_columns = []
for key in train_x.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

train_spec = tf.estimator.TrainSpec(input_fn=lambda:iris_data.train_input_fn(train_x, train_y,
                                                 10), max_steps=100)
eval_spec = tf.estimator.EvalSpec(input_fn=lambda:iris_data.eval_input_fn(test_x, test_y,
                                                10), hooks=[MyHook()])
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)

它打印:

INFO:tensorflow:Saving checkpoints for 100 into /var/folders/***/model.ckpt.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-10-18-20:19:28
I'm here ! 1.9386581
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/***/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-10-18-20:19:28

推荐阅读