首页 > 解决方案 > 如何在分布式 Tensorflow 中并行化 python 输入管道

问题描述

我有一个非平凡的输入管道,它包括读取地面实况和原始数据并对它们执行预处理,用 Python 编写。为单个样本运行输入管道需要很长时间,因此我有多个进程(来自 python 多处理包)并行运行并排队以快速执行操作并预取数据。然后使用 feed_dict 将输出馈送到我的网络。在我的训练循环中,这个过程的开销比实际的 tf.Session.run() 时间少 2 个数量级。我试图通过用 tf.py_func 包装我的 read+preprocess 函数来转移到 tf.data API,但它运行缓慢,可能是由于 GIL,即使增加了多个调用的数量也是如此。我想将我的训练扩展到多台机器,但不确定在这种情况下数据获取的行为如何,还有

所以,基本上我的问题是:如何在多个 CPU 内核上并行运行 tf.data api 输入管道中的 python 函数?

标签: pythontensorflowtensorflow-datasets

解决方案


一些澄清,tf.py_func可以与您的并行运行sess.run()(因为sess.run()发布了 GIL),但您不能tf.py_func在同一个 python 进程中运行多个。

在这种情况下,通常的答案是离线进行预处理,将结果保存在磁盘上(例如使用 TFRecord 格式),在训练期间从文件中读取准备好的数据。您可能可以使用多处理之类的东西并行化离线预处理。

如果您可以使用 tf 操作来表达您的预处理,您可以使用 并行运行它Dataset.map,但在tf.data. 如果上述方法由于某种原因不起作用,您可能必须自己连接多处理。

解决此问题的一种方法如下。让多个进程产生您的输入,将它们放入 multiprocessing.Queue (或共享内存并围绕它进行一些锁定)。使用生成器函数实现接收端并使用from_generator创建数据集。


推荐阅读