首页 > 解决方案 > 如何检查 tf.estimator.inputs.numpy_input_fn 的内容?

问题描述

我想在一组数据上反复训练我的张量流图,我想tf.estimator.inputs.numpy_input_fn这可能就是我想要的。我发现批量大小、重复、时期和迭代器之间的区别令人难以置信,因此我开始尝试检查我的数据集的内容以试图弄清楚实际发生了什么。但是,每当我尝试这样做时,我的程序就会挂起。

这是我想出的最小的测试用例来重现这个:

import tensorflow as tf
import numpy

class TestMock(tf.test.TestCase):
    def test(self):
        inputs = numpy.array(range(10))
        targets = numpy.array(range(10,20))

        input_fn = tf.estimator.inputs.numpy_input_fn(
            x=inputs,
            y=targets,
            batch_size=1,
            num_epochs=2,
            shuffle=False)

        print input_fn()
        with self.test_session() as sess:
            # sess.run(input_fn()[0]) # it'll hang if I run this
            pass

if __name__ == '__main__':
    tf.test.main()

该程序输出

(<tf.Tensor 'fifo_queue_DequeueUpTo:1' shape=(?,) dtype=int64>, <tf.Tensor 'fifo_queue_DequeueUpTo:2' shape=(?,) dtype=int64>)

这似乎是合理的,但是一旦我尝试运行该sess.run行,我的程序就会冻结,我必须终止该进程。我在这里做错了什么?

我想要做的是确保我输入到我的流程中的数据实际上是我认为的,但如果没有检查数据的能力,我认为我无法做到这一点。

标签: pythontensorflow

解决方案


从上面的打印语句我们可以推断出input_fn返回queue ops,我们需要使用start_queue_runnersandCoordinator来运行它们:

 features_op, labels_op = input_fn()
 with tf.Session() as sess:
     # initialise and start the queues.
     sess.run(tf.local_variables_initializer())

     coordinator = tf.train.Coordinator()
     _ = tf.train.start_queue_runners(coord=coordinator)

    print(sess.run([features_op, labels_op]))

    #[array([0]), array([10])]

推荐阅读