TensorFlow error initializing a dataset iterator in a MonitoredTrainingSession

Problem description

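Hi everyone, I need some help. I am trying to implement ResNet-101 ImageNet classification in TensorFlow without using Estimators, as a way to study deep learning and learn how to use TensorFlow. My problem is that MonitoredTrainingSession does not initialize my iterator.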

I have read several posts about this issue and tried to handle it with a hook, but it still fails and I don't understand why.

After I create the MonitoredTrainingSession, it first initializes train_iterator, runs until it hits an OutOfRangeError, and then runs the validation step.

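So far this looks fine, but once validation has finished and the training step runs again, I get an error from iterator.get_next(). It says the iterator has not been initialized, but my hook clearly calls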

 session.run(self._initializer, feed_dict={filenames: self._filenames})

I am sure of this because I can see the following message, which I print to check whether the initializer has run:

iter_val.initializer after_create_session is called 0 times
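For comparison, here is a minimal, self-contained sketch of such an initializer hook, assuming TensorFlow 1.15 or TensorFlow 2.x through tensorflow.compat.v1. Unlike the question's hook, it receives the placeholder as an argument instead of reading a module-level filenames global. DatasetInitializerHook and read_all are illustrative names, and the dataset simply yields the fed filenames instead of parsing TFRecords. When the graph is built once and the hook is attached to the session that consumes the iterator, after_create_session runs once and get_next() succeeds.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


class DatasetInitializerHook(tf.train.SessionRunHook):
    """Runs an iterator initializer every time a session is created."""

    def __init__(self, initializer, placeholder=None, values=None):
        self._initializer = initializer
        self._placeholder = placeholder  # passed in explicitly, not read from a global
        self._values = values
        self.calls = 0

    def after_create_session(self, session, coord):
        del coord  # unused
        feed_dict = None
        if self._placeholder is not None:
            feed_dict = {self._placeholder: self._values}
        session.run(self._initializer, feed_dict=feed_dict)
        self.calls += 1


def read_all(values):
    """Builds a tiny pipeline once and reads it to the end through the hook."""
    tf.reset_default_graph()
    tf.train.get_or_create_global_step()
    filenames = tf.placeholder(tf.string, shape=[None])
    iterator = tf.data.make_initializable_iterator(
        tf.data.Dataset.from_tensor_slices(filenames))
    next_item = iterator.get_next()
    hook = DatasetInitializerHook(iterator.initializer, filenames, values)

    items = []
    # log_step_count_steps=0 turns off the default step-counter hook.
    with tf.train.MonitoredTrainingSession(hooks=[hook],
                                           log_step_count_steps=0) as sess:
        while not sess.should_stop():
            # The final call raises OutOfRangeError, which the session's
            # context manager swallows.
            items.append(sess.run(next_item))
    return items, hook.calls
```

Calling read_all(["a.tfrecord", "b.tfrecord"]) returns the two filenames (as bytes) and a call count of 1.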

What am I doing wrong?

The run sequence is:

  1. The training step runs fine (epoch = 0)

  2. The validation step runs fine (epoch = 0)

  3. The training step fails (epoch = 1)

Please ignore the Horovod (hvd) calls in the code; I am not actually using Horovod at the moment.

Here is my code. Please help me fix it and tell me what is wrong with it.

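import sys
from random import shuffle

import horovod.tensorflow as hvd
import tensorflow as tf
from tqdm import tqdm

# FLAGS, dataset_input_fn and model_class are defined elsewhere in the project (not shown).


class _DatasetInitializerHook(tf.train.SessionRunHook):
    def __init__(self, initializer, filenames=None, name=""):
        self._initializer = initializer
        self._filenames = filenames or []  # avoid a shared mutable default argument
        self._name = name
        self._cnt = 0
        self._before_runCnt = 0

    def begin(self):
        pass

    def after_create_session(self, session, coord):
        del coord

        if len(self._filenames) == 0:
            session.run(self._initializer)
        else:
            # `filenames` is the module-level tf.placeholder created in the __main__ block below
            session.run(self._initializer, feed_dict={filenames: self._filenames})
        print(self._name, "after_create_session is called {} times".format(self._cnt))
        self._cnt += 1
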

if __name__ == "__main__":
    if len(sys.argv) > 1:
        nlogs = sys.argv[1]
    else:
        nlogs = 0

    hvd.init()
    b_imagenet=False
    if b_imagenet:
        training_filenames = ['/data/tfrecords/imagenet2012_train_shard{}.tfrecord'.format(i) for i in range(129)]
    else:
        training_filenames = ['/data/cifar-10-tfrecords/train_shard{}.tfrecord'.format(i) for i in range(1, 2, 1)]

    filenames = tf.placeholder(tf.string, shape=[None])

    trainData = dataset_input_fn(is_training=True, filename=filenames, nworkers=hvd.size(), workeridx=hvd.rank(),
                                 batch_size=FLAGS.batchSize, prefetch_size=FLAGS.prefetch_buffer_size, repeat=1,
                                 shuffle_buffer_size=FLAGS.shuffle_buffer_size)

    valData = dataset_input_fn(is_training=False, filename=FLAGS.validationfile, nworkers=hvd.size(), workeridx=hvd.rank(),
                               batch_size=1, prefetch_size=FLAGS.prefetch_buffer_size, repeat=1, shuffle_buffer_size=1)
    # Pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    for i in tqdm(range(FLAGS.nepoch)):
        shuffle(training_filenames)
        model = model_class(nCls=FLAGS.nClasses, img_width=FLAGS.width, img_height=FLAGS.height,
                            learning_rate=FLAGS.learning_rate, weight_decay=FLAGS.weight_decay)

        iter_train = trainData.make_initializable_iterator()
        train_op = model.build_model(iter_train.get_next(), is_trainig=True, hvd=None)

        train_hooks = [hvd.BroadcastGlobalVariablesHook(0),
                       _DatasetInitializerHook(iter_train.initializer, training_filenames, "iter_train.initializer")]


        with tf.train.MonitoredTrainingSession(checkpoint_dir="./tmp/train_logs", config=config, hooks=train_hooks,
            save_checkpoint_secs=30) as sess:

            try:
                while True:
                    opt = sess.run([train_op])
            except tf.errors.OutOfRangeError:
                pass

        iter_val = valData.make_initializable_iterator()
        prediction_result = model.build_model(iter_val.get_next(), is_trainig=False, hvd=None)

        validation_hooks = [hvd.BroadcastGlobalVariablesHook(0),
                            _DatasetInitializerHook(iter_val.initializer, [], "iter_val.initializer")]


        with tf.train.MonitoredTrainingSession(checkpoint_dir="./tmp/train_logs", config=config, hooks=validation_hooks) as sess:
            try:
                while True:
                    result = sess.run([prediction_result])
            except tf.errors.OutOfRangeError:
                pass
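One common way to avoid adding a new iterator (and a new model) to the graph in every epoch is to build the input pipelines once, open a single MonitoredTrainingSession, and re-run each iterator's initializer at the start of every epoch. The sketch below assumes TensorFlow 1.15 or tensorflow.compat.v1. build_pipeline and drain are hypothetical helpers (build_pipeline stands in for the question's dataset_input_fn), and fetching the batch tensors stands in for running train_op and prediction_result; it does not show how model_class would share variables between its training and validation graphs.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def build_pipeline(filenames, batch_size):
    # Hypothetical stand-in for dataset_input_fn: it just batches the filenames.
    return tf.data.Dataset.from_tensor_slices(filenames).batch(batch_size)


def drain(sess, fetch):
    """Runs `fetch` until its iterator is exhausted; returns the step count."""
    steps = 0
    try:
        while True:
            sess.run(fetch)
            steps += 1
    except tf.errors.OutOfRangeError:
        return steps


def train(num_epochs=2):
    tf.reset_default_graph()
    tf.train.get_or_create_global_step()

    # Build every op exactly once, before the session finalizes the graph.
    filenames = tf.placeholder(tf.string, shape=[None])
    train_iter = tf.data.make_initializable_iterator(build_pipeline(filenames, 2))
    val_iter = tf.data.make_initializable_iterator(tf.data.Dataset.range(3).batch(1))
    train_batch = train_iter.get_next()  # stands in for train_op
    val_batch = val_iter.get_next()      # stands in for prediction_result

    history = []
    with tf.train.MonitoredTrainingSession(log_step_count_steps=0) as sess:
        for _ in range(num_epochs):
            # Re-initialize the same iterators at the start of each epoch.
            sess.run(train_iter.initializer,
                     feed_dict={filenames: ["a", "b", "c"]})
            train_steps = drain(sess, train_batch)
            sess.run(val_iter.initializer)
            history.append((train_steps, drain(sess, val_batch)))
    return history
```

With three filenames and a batch size of 2, each epoch runs two training steps and three validation steps. Shuffling the filename list each epoch, as the question does, still fits this structure, because the shuffled list is fed through the placeholder whenever the initializer runs.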

This is the error message I get:

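tensorflow.python.framework.errors_impl.FailedPreconditionError: GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element. [[node IteratorGetNext (defined at workspace/multi_gpu/main.py:128) ]]

Errors may have originated from an input operation. Input Source operations connected to node IteratorGetNext: IteratorV2_2 (defined at workspace/multi_gpu/main.py:126)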
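The error names IteratorV2_2. Under TensorFlow's default op naming, that is the third iterator op in the graph, which suggests the failing get_next() belongs to the iterator created for the second training epoch: the loop calls make_initializable_iterator() again in every epoch, and each call adds a separate iterator with its own state and its own initializer. The minimal sketch below (assuming TensorFlow 1.15 or tensorflow.compat.v1; the function name is illustrative) shows that initializing one iterator does not initialize another one built from the same dataset.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def get_next_on_two_iterators():
    tf.reset_default_graph()
    dataset = tf.data.Dataset.range(3)
    # Each call adds a separate iterator op (with its own state) to the graph.
    first = tf.data.make_initializable_iterator(dataset)
    second = tf.data.make_initializable_iterator(dataset)
    first_next, second_next = first.get_next(), second.get_next()

    with tf.Session() as sess:
        sess.run(first.initializer)  # only the first iterator is initialized
        value = sess.run(first_next)
        try:
            sess.run(second_next)
            error = None
        except tf.errors.FailedPreconditionError as exc:
            error = exc
    return value, error
```

The first iterator yields 0, while get_next() on the second raises the same FailedPreconditionError as in the question.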

Tags: python, tensorflow, iterator

Solution


Try putting the initializer into a tf.train.Scaffold:

# train_init_operator: the training iterator's initializer op (e.g. iter_train.initializer)
scaffold = tf.train.Scaffold(local_init_op=train_init_operator)

and passing the scaffold to MonitoredTrainingSession:

with tf.train.MonitoredTrainingSession(scaffold=scaffold, ...
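A self-contained sketch of this suggestion, assuming TensorFlow 1.15 or tensorflow.compat.v1 and a small in-memory dataset in place of the TFRecord files (read_with_scaffold is an illustrative name). Passing local_init_op replaces the Scaffold's default local initialization, so the sketch groups the iterator initializer with tf.local_variables_initializer() and tf.tables_initializer(). local_init_op is run by the session machinery rather than by your own sess.run call, so the sketch uses constant data instead of feeding a placeholder. Because local_init_op runs each time a session is created, this matches the question's structure, where every training and validation phase opens its own MonitoredTrainingSession.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def read_with_scaffold():
    tf.reset_default_graph()
    tf.train.get_or_create_global_step()

    # Constant data: local_init_op is run during session creation, so there
    # is no convenient place to feed a placeholder here.
    dataset = tf.data.Dataset.from_tensor_slices(tf.constant([1, 2, 3, 4])).batch(2)
    iterator = tf.data.make_initializable_iterator(dataset)
    next_batch = iterator.get_next()

    # Passing local_init_op replaces the default one, so keep the usual local
    # initialization alongside the iterator initializer.
    local_init_op = tf.group(tf.local_variables_initializer(),
                             tf.tables_initializer(),
                             iterator.initializer)
    scaffold = tf.train.Scaffold(local_init_op=local_init_op)

    batches = []
    with tf.train.MonitoredTrainingSession(scaffold=scaffold,
                                           log_step_count_steps=0) as sess:
        while not sess.should_stop():
            # The last call raises OutOfRangeError; the context manager
            # treats it as a normal end of input.
            batches.append(sess.run(next_batch).tolist())
    return batches
```

No hook is needed here: the iterator is already initialized when the first sess.run executes.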
