首页 > 解决方案 > 在 python 多处理队列中推进 tensorflow 数据集迭代器

问题描述

有没有办法在这个例子中移动迭代器?

import tensorflow as tf
import numpy as np
from multiprocessing import Process, Queue

def store(batch, queue):
    while True:
        queue.put(batch)


if __name__=='__main__':
    pqueue = Queue()
    a1 = np.arange(1000)

    m = tf.data.Dataset.from_tensor_slices(a1).repeat().batch(1)
    iter_m = m.make_one_shot_iterator()
    m_init_ops = iter_m.make_initializer(m)
    next_m = iter_m.get_next()

    with tf.Session() as sess:
        batch = sess.run(next_m)
        pp_process = Process(target=store,args=(batch, pqueue,))
        pp_process.daemon = True
        pp_process.start()

        for i in range(10):
            print(pqueue.get())

我的想法是将处理后的数据存储在 tensorflow 可以访问的队列中进行训练,不幸的是我无法推进迭代器。任何建议将不胜感激。

当前输出为

[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]

标签: pythontensorflowtensorflow-datasets

解决方案


张量流多线程

迭代器没有前进,因为您在技术上只执行一次 get_next 操作:sess.run(next_m)。如果您只使用 tensorflow 多线程,则只需将其移入store函数即可获得所需的结果:

def store(sess, next_m, queue):
    while True:
        queue.put(sess.run(next_m))

# batch = sess.run(next_m) <- Remove
pp_process = Thread(target=store,args=(sess, next_m, pqueue,)) # <- Thread with correct args passed

TensorFlow 多处理

但是,对于多处理,您还应该确保在已经创建会话之后永远不会实例化(分叉)新进程,因为会话对象不可序列化。
在您的情况下,您可以简单地在 store 函数中创建一个新会话并在分叉后启动主会话:

from multiprocessing import Process, Queue

import numpy as np
import tensorflow as tf


def store(next_m, queue):
    with tf.Session() as sess:
        while True:
            queue.put(sess.run(next_m))


if __name__ == '__main__':
    ...
    pp_process = Process(target=store, args=(next_m, pqueue,))
    pp_process.daemon = True
    pp_process.start() # <- Fork before starting this session!

    with tf.Session() as sess:
        for i in range(10):
            print(pqueue.get())

推荐阅读