首页 > 解决方案 > 对象作为 Python 多处理池的参数默认为相同的对象内存(错误)

问题描述

我正在尝试将对象作为参数传递给我想使用 python 多处理池并行执行每个函数的函数。但是,每个多处理函数调用只调用列表中的最后一个对象。

doSomething() 中的每个打印对象都使用相同的 Temp() 对象,并且具有相同的内存地址。输入的临时对象都是唯一的,但多处理池似乎只使用每个 doSomething() 函数调用的最后一个对象。每个 doSomething() 都有相同的 Temp() 对象。

为什么会这样?如何正确地将每个对象传递到多处理池中以并行执行?

def doSomething(temp):
    print(temp)

class Temp():
    def __init__(self, b):
        self.a = 10
        self.b = b

def main():

    # Create the input args into the multiprocess pool map function
    temps_args = ()

    bs = [-5, -10, -15]
    for b in bs:

        temp = Temp(b)
        print(temp)
        temps_args += (temp,)

    # Setup multi-procesing pool and execute multiprocessing cases
    pool = multiprocessing.Pool()
    res = pool.map(doSomething, temps_args)

if __name__ == '__main__':
    main()

标签: pythonmultiprocessing

解决方案


创建运行函数的进程时,它将根据父进程的当前状态从自己的内存开始。在多处理中,单个进程可能会被多次重用以减少启动时间。在 3 条记录的情况下,它们很可能都命中了相同的进程。

在继续之前:

当一个 Python 对象被垃圾回收时,该对象的 id 会重新进入循环。以此为例:

class Foo:
    pass


# All instances of Foo are kept in memory before ids are fetched
ids_for_alive = set(id(foo) for foo in [Foo() for _ in range(1000)])
print(len(ids_for_alive))
# prints 1000

# After every call to id, the last Foo instance reference is dropped and may be gc'd
ids_for_gc = set(id(Foo()) for _ in range(1000))
print(len(ids_for_gc))
# usually prints 1

对我来说,这只是1000在对象保存在内存中时打印,并且仅1在没有保存对对象的引用时打印。他们被允许收集垃圾,这意味着 id 保持可用,并且单个 id 被重复使用 1000 次。

回到多处理:

当工作人员被分配一项工作并接收其参数时,它们被加载到内存中并分配一个 id。当该工作结束时,对参数的引用,这意味着它可以被垃圾收集,并再次可用。例如:

class Foo:
    def __init__(self, val):
        self.val = val


def get_id(foo):
    print(foo.val)
    return id(foo)


if __name__ == '__main__':
    pool = multiprocessing.Pool()

    foos = [Foo(i) for i in range(3)]
    ids = set(pool.map(get_id, foos))

    print(f"Total ids: {len(ids)}")

对我来说,这是打印的:

0
1
2
Total ids: 1

现在,您可以做的一件有趣的事情(根本没有效率)是限制允许单个进程执行的作业数量。如果将此值设置为 1,则每次调用任务都会获得一个新进程,并且很可能会获得一个新的地址空间和一组 id 值。

if __name__ == '__main__':
    pool = multiprocessing.Pool(maxtasksperchild=1)

    foos = [Foo(i) for i in range(3)]
    ids = set(pool.map(get_id, foos))

    print(f"Total ids: {len(ids)}")

印刷:

0
1
2
Total ids: 3

希望这有助于澄清事情!


推荐阅读