python - 如何检查 tf.estimator.inputs.numpy_input_fn 的内容?
问题描述
我想在一组数据上反复训练我的张量流图,我想tf.estimator.inputs.numpy_input_fn
这可能就是我想要的。我发现批量大小、重复、时期和迭代器之间的区别令人难以置信,因此我开始尝试检查我的数据集的内容以试图弄清楚实际发生了什么。但是,每当我尝试这样做时,我的程序就会挂起。
这是我想出的最小的测试用例来重现这个:
import tensorflow as tf
import numpy
class TestMock(tf.test.TestCase):
def test(self):
inputs = numpy.array(range(10))
targets = numpy.array(range(10,20))
input_fn = tf.estimator.inputs.numpy_input_fn(
x=inputs,
y=targets,
batch_size=1,
num_epochs=2,
shuffle=False)
print input_fn()
with self.test_session() as sess:
# sess.run(input_fn()[0]) # it'll hang if I run this
pass
if __name__ == '__main__':
tf.test.main()
该程序输出
(<tf.Tensor 'fifo_queue_DequeueUpTo:1' shape=(?,) dtype=int64>, <tf.Tensor 'fifo_queue_DequeueUpTo:2' shape=(?,) dtype=int64>)
这似乎是合理的,但是一旦我尝试运行该sess.run
行,我的程序就会冻结,我必须终止该进程。我在这里做错了什么?
我想要做的是确保我输入到我的流程中的数据实际上是我认为的,但如果没有检查数据的能力,我认为我无法做到这一点。
解决方案
从上面的打印语句我们可以推断出input_fn
返回queue ops
,我们需要使用start_queue_runners
andCoordinator
来运行它们:
features_op, labels_op = input_fn()
with tf.Session() as sess:
# initialise and start the queues.
sess.run(tf.local_variables_initializer())
coordinator = tf.train.Coordinator()
_ = tf.train.start_queue_runners(coord=coordinator)
print(sess.run([features_op, labels_op]))
#[array([0]), array([10])]
推荐阅读
- python-3.x - 连接数据帧,重命名新索引,删除旧索引
- node.js - 拥有数千个虚拟用户的 Node.JS 负载测试工具
- openssl - nelem.h 在哪里?(对于 OpenSSL)
- javascript - 如何使用 Javascript 相对于时间更改 HTML 中文本的颜色?
- c# - HttpClient 不发送响应
- java - [Java][Spring-boot] @NotNull 注解不是 DDL
- regex - 我需要一个只接受这种格式 YYYYWWWW 的正则表达式(例如 2021WW12)
- html - Bootstrap 模态文本点击
- python - Selenium 在打开 Python 后自动关闭窗口
- unity3d - 总是让 SpatialAnchorManager 配置不正确