python - 使用 dask 并行创建树
问题描述
我需要关于一个我很确定 dask 可以解决的问题的帮助。但我不知道如何解决它。
我需要递归地构造一棵树。
对于每个节点,如果满足一个标准,则进行计算(compute_val
),否则将创建 2 个新子节点。对孩子进行相同的治疗(build
)。然后,如果节点的所有子节点都执行了计算,我们可以继续进行合并(merge
)。合并可以执行子项的融合(如果它们都符合标准)或不执行任何操作。目前我只能并行化第一级,我不知道应该使用哪个 dask 工具更有效。这是我想要实现的简化的 MRE 序列:
import numpy as np
import time
class Node:
def __init__(self, level):
self.level = level
self.val = None
def merge(node, childs):
values = [child.val for child in childs]
if all(values) and sum(values)<0.1:
node.val = np.mean(values)
else:
node.childs = childs
return node
def compute_val():
time.sleep(0.1)
return np.random.rand(1)
def build(node):
print(node.level)
if (np.random.rand(1) < 0.1 and node.level>1) or node.level>5:
node.val = compute_val()
else:
childs = [build(Node(level=node.level+1)) for _ in range(2)]
node = merge(node, childs)
return node
tree = build(Node(level=0))
解决方案
据我了解,处理递归(或任何动态计算)的方式是在任务中创建任务。
我正在尝试类似的东西,所以下面是我的 5 分钟说明性解决方案。您必须根据算法的特征对其进行优化。
请记住,任务会增加开销,因此您需要对计算进行分块以获得最佳结果。
相关文档:
接口参考:
- https://distributed.dask.org/en/latest/api.html#distributed.worker_client
- https://distributed.dask.org/en/latest/api.html#distributed.Client.gather
- https://distributed.dask.org/en/latest/api.html#distributed.Client.submit
import numpy as np
import time
from dask.distributed import Client, worker_client
# Create a dask client
# For convenience, I'm creating a localcluster.
client = Client(threads_per_worker=1, n_workers=8)
client
class Node:
def __init__(self, level):
self.level = level
self.val = None
self.childs = None # This was missing
def merge(node, childs):
values = [child.val for child in childs]
if all(values) and sum(values)<0.1:
node.val = np.mean(values)
else:
node.childs = childs
return node
def compute_val():
time.sleep(0.1) # Is this required.
return np.random.rand(1)
def build(node):
print(node.level)
if (np.random.rand(1) < 0.1 and node.level>1) or node.level>5:
node.val = compute_val()
else:
with worker_client() as client:
child_futures = [client.submit(build, Node(level=node.level+1)) for _ in range(2)]
childs = client.gather(child_futures)
node = merge(node, childs)
return node
tree_future = client.submit(build, Node(level=0))
tree = tree_future.result()