Parallelizing a data iterator with dask

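Problem description

I have implemented a data iterator that takes objects from two numpy arrays and runs a very CPU-intensive computation on them before returning them. I would like to parallelize this with Dask. Here is a simplified version of the iterator class: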

import numpy as np

class DataIterator:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        item1, item2 = self.x[idx], self.y[idx]
        # Do some very heavy computations here by
        # calling other methods and return  
        return item1, item2

x = np.random.randint(20, size=(20,))
y = np.random.randint(50, size=(20,))

data_gen = DataIterator(x, y)

Right now I iterate over the items with a simple for loop like this:

for i, (item1, item2) in enumerate(data_gen):
    print(item1, item2)

But this is really slow. Can someone help me parallelize it with dask?

Tags: python, python-3.x, parallel-processing, multiprocessing, dask

Solution


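The simplest way to do this is to use dask.delayed and decorate the __getitem__ method. Another approach is to convert x and y to dask arrays and then do the computation inside __getitem__ with dask.array operations. Since you did not give any details about the heavy computation, the examples below are only illustrative.

dask.delayed: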

from dask import delayed
import numpy as np

class DataIterator:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    @delayed
    def __getitem__(self):
        # Use the instance attributes, not the module-level x and y
        item1 = self.x.mean()
        item2 = self.y.sum()
        # Do some very heavy computations here by
        # calling other methods and return  
        return item1, item2

x = np.random.randint(20, size=(20,))
y = np.random.randint(50, size=(20,))

data_gen = DataIterator(x, y)
x_mean, y_sum = data_gen.__getitem__().compute()

Output (the values differ between runs because x and y are random):

x_mean, y_sum
Out[41]: (8.45, 479)
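
The example above wraps a single call. To parallelize the loop from the question, one option is to build one delayed task per index and evaluate them all in a single dask.compute call. The following is only a minimal sketch: heavy_computation and compute_all are hypothetical names, and heavy_computation is a stand-in for the real work, which the question does not show.

```python
import dask
import numpy as np
from dask import delayed


def heavy_computation(a, b):
    # Hypothetical placeholder for the CPU-heavy work from the question.
    return int(a) ** 2, int(b) + 1


class DataIterator:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return heavy_computation(self.x[idx], self.y[idx])


def compute_all(data_gen, scheduler="threads"):
    # Build one lazy task per index; nothing is executed at this point.
    tasks = [delayed(data_gen.__getitem__)(i) for i in range(len(data_gen))]
    # Evaluate all tasks in one call so the scheduler can run them concurrently.
    return list(dask.compute(*tasks, scheduler=scheduler))


rng = np.random.default_rng(0)
x = rng.integers(20, size=20)
y = rng.integers(50, size=20)

data_gen = DataIterator(x, y)
results = compute_all(data_gen)
```

With the threaded scheduler, pure-Python CPU-bound code is still limited by the GIL. If that applies to your computation, passing scheduler="processes" may help, provided the objects involved can be pickled.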

dask.array:

import dask.array as da
import numpy as np

class DataIterator:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self):
        # self.x and self.y are dask arrays here, so these calls are lazy
        item1 = self.x.mean()
        item2 = self.y.sum()
        # Do some very heavy computations here by
        # calling other methods and return  
        return item1.compute(), item2.compute()

x = np.random.randint(20, size=(20,))
y = np.random.randint(50, size=(20,))

x = da.from_array(x, chunks=x.shape[0] // 4)
y = da.from_array(y, chunks=y.shape[0] // 4)

data_gen = DataIterator(x, y)
x_mean, y_sum = data_gen.__getitem__()

Output (the values differ between runs because x and y are random):

x_mean, y_sum
Out[50]: (10.4, 461)
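
If the heavy computation is applied to each element rather than reduced to a single number, the dask.array route might look like the sketch below. It assumes the work can be written as a function that operates on one NumPy chunk at a time; heavy_block is a hypothetical placeholder for that function.

```python
import dask.array as da
import numpy as np


def heavy_block(block):
    # Hypothetical placeholder: receives one NumPy chunk and returns
    # an array of the same shape.
    return block ** 2 + 1


x = np.arange(20)
dx = da.from_array(x, chunks=5)  # 4 chunks of 5 elements each

# map_blocks applies heavy_block to every chunk; compute() runs the chunks
# through the scheduler and returns a NumPy array.
result = dx.map_blocks(heavy_block, dtype=x.dtype).compute()
```

Each chunk becomes a separate task, so the chunk size controls how finely the work is split.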
