python - 使用 dask 并行化数据迭代器
问题描述
我已经实现了一个数据迭代器,它从两个numpy
数组中获取对象,并在返回它们之前对它们进行非常密集的 CPU 计算。我想使用Dask
. 这是这个迭代器类的一个简单版本:
import numpy as np
class DataIterator:
def __init__(self, x, y):
self.x = x
self.y = y
def __len__(self):
return len(self.x)
def __getitem__(self, idx):
item1, item2 = x[idx], y[idx]
# Do some very heavy computations here by
# calling other methods and return
return item1, item2
x = np.random.randint(20, size=(20,))
y = np.random.randint(50, size=(20,))
data_gen = DataIterator(x, y)
现在,我使用这样的简单 for 循环遍历项目:
for i, (item1, item2) in enumerate(data_gen):
print(item1, item2)
但这真的很慢。有人可以帮助使用 dask 并行化它吗?
解决方案
实现这一点的最简单方法是使用 dask.delayed 并装饰getitem方法。另一种方法是将 x, y 转换为 dask 数组,然后使用 dask.array 命令在 getitem 方法中进行计算。由于您没有提供繁重计算的详细信息,以下示例仅供参考。
Dask.delayed:
from dask import delayed
import numpy as np
class DataIterator:
def __init__(self, x, y):
self.x = x
self.y = y
def __len__(self):
return len(self.x)
@delayed
def __getitem__(self):
item1 = x.mean()
item2 = y.sum()
# Do some very heavy computations here by
# calling other methods and return
return item1, item2
x = np.random.randint(20, size=(20,))
y = np.random.randint(50, size=(20,))
data_gen = DataIterator(x, y)
x_mean, y_sum = data_gen.__getitem__().compute()
输出:
x_mean, y_sum
Out[41]: (8.45, 479)
Dask.array:
import dask.array as da
import numpy as np
class DataIterator:
def __init__(self, x, y):
self.x = x
self.y = y
def __len__(self):
return len(self.x)
def __getitem__(self):
item1 = x.mean()
item2 = y.sum()
# Do some very heavy computations here by
# calling other methods and return
return item1.compute(), item2.compute()
x = np.random.randint(20, size=(20,))
y = np.random.randint(50, size=(20,))
x = da.from_array(x, chunks = x.shape[0] // 4)
y = da.from_array(y, chunks = y.shape[0] // 4)
data_gen = DataIterator(x, y)
x_mean, y_sum = data_gen.__getitem__()
输出:
x_mean, y_sum
Out[50]: (10.4, 461)
推荐阅读
- php - Laravel "Storage::download()" 方法是如何工作的?
- ios - 如何使用 SwiftUI 在列表中的项目上实现长按手势?
- javascript - 如何每“t”分钟刷新一次 JWT 令牌?
- flutter - 如何防止 AppBar 在 Flutter 中裁剪其子项?
- typescript - 为什么 TS 中的泛型接口不能正确推断类型?
- sqlite - Heroku 不断删除我的聊天机器人的数据库,我该如何阻止它?
- r - 如何重新创建相关矩阵以运行 p.adjust?
- qt - qt 编译错误,包括 QtCore 导致 utf-8 错误
- android-studio - 未在 Android Studio 中安装 haxm
- mongodb - MongoDB如何使用方面计算未读消息