首页 > 解决方案 > 如何直接总结使用python多处理获得的结果而不返回所有单个结果以节省内存?

问题描述

我有一个创建大掩码(布尔数组)的函数。我想多次调用此函数并创建一个形状相同的总掩码,该掩码在任何单个掩码中为真的索引处为真。

由于掩码的计算需要很长时间,我将其并行化,但该函数现在消耗了大量内存,因为我首先创建所有单独的掩码,然后将它们组合起来,这意味着我必须存储所有约 40.000 个单独的掩码。在使用多处理计算下一个掩码之前,是否有可能将返回的单个掩码直接添加到总掩码中?

这是该问题的示例代码:

import numpy as np
from multiprocessing import Pool


def return_something(seed):
    np.random.seed(seed)
    return np.random.choice([True, False], size=shape, p=[0.1, 0.9])


shape = (50, 50)
ncores = 4
seeds = np.random.randint(low=0, high=np.iinfo(np.int32).max, size=10)

# Without parallelisation, very slow:
mask = np.zeros(shape, dtype=bool)
for seed in seeds:
    mask |= return_something(seed)


# With parallelisation, takes too much memory
p = Pool(ncores)
mask_parallel = np.any(list(p.imap(return_something, seeds)), axis=0)

我想我对 (i)map 函数的理解不够。我知道 multiprocessing.imap 返回一个生成器,并且可以使用 tqdm 和以下代码显示例如进度条:

list(tqdm.tqdm(p.imap(fct, inputs), total=len(inputs))

由于进度条在多处理运行期间更新,我认为必须可以在运行期间访问结果并可能总结它们,但我不知道如何。

谢谢你的帮助!

标签: pythonpython-multiprocessing

解决方案


遍历种子没有意义,因为您正在创建一个非常大的数组 ech time in return_somethign。因此,您必须将此数组创建分割成一些子创建并遍历这些子创建。该Pool.map()方法返回每次迭代中执行函数的结果列表。向您展示针对您的案例的一般实施。我正在做的只是将每一行的创建并行化并通过map()函数将它们放在一起。

import numpy as np
import multiprocessing as mp

def return_something(i):
    mask = np.random.choice([True, False], size=(shape[0],), p=[0.1, 0.9])
    return mask

shape = (5000, 5000)

if __name__ == "__main__":
    pool = mp.Pool(mp.cpu_count())
    results = pool.map(return_something, [i for i in range(shape[1])])
    pool.close()
    print(len(results))

关于您的评论,我正在展示一种在计算结果项目后将其附加到列表的方法(即时)

import numpy as np
from multiprocessing import Pool
import time

def return_something(seed):
    np.random.seed(seed)
    return np.random.choice([True, False], size=shape, p=[0.1, 0.9])


shape = (50, 50)
ncores = 4
seeds = np.random.randint(low=0, high=np.iinfo(np.int32).max, size=100000)

mask = []

if __name__ == "__main__":
    p = Pool(12)
    start = time.time()
    for res in p.imap(return_something, seeds, chunksize=1):
        mask.append(res)
        print("{} (Time elapsed: {}s)".format(len(res), time.time() - start))

    p.close()
    print(len(mask))

推荐阅读