首页 > 解决方案 > 使用 dask 并行创建树

问题描述

我需要关于一个我很确定 dask 可以解决的问题的帮助。但我不知道如何解决它。

我需要递归地构造一棵树。

对于每个节点,如果满足一个标准,则进行计算(compute_val),否则将创建 2 个新子节点。对孩子进行相同的治疗(build)。然后,如果节点的所有子节点都执行了计算,我们可以继续进行合并(merge)。合并可以执行子项的融合(如果它们都符合标准)或不执行任何操作。目前我只能并行化第一级,我不知道应该使用哪个 dask 工具更有效。这是我想要实现的简化的 MRE 序列:

import numpy as np
import time

class Node:
    def __init__(self, level):
        self.level = level
        self.val = None

def merge(node, childs):
    values = [child.val for child in childs]
    if all(values) and sum(values)<0.1:
        node.val = np.mean(values)
    else:
        node.childs = childs
    return node        

def compute_val():
    time.sleep(0.1)
    return np.random.rand(1)

def build(node):
    print(node.level)
    if (np.random.rand(1) < 0.1 and node.level>1) or node.level>5:
        node.val = compute_val()
    else:
        childs = [build(Node(level=node.level+1)) for _ in range(2)]
        node = merge(node, childs)
    return node

tree = build(Node(level=0))

标签: pythonrecursiondask

解决方案


据我了解,处理递归(或任何动态计算)的方式是在任务中创建任务。

我正在尝试类似的东西,所以下面是我的 5 分钟说明性解决方案。您必须根据算法的特征对其进行优化。

请记住,任务会增加开销,因此您需要对计算进行分块以获得最佳结果。

相关文档:

接口参考:

import numpy as np
import time
from dask.distributed import Client, worker_client

# Create a dask client
# For convenience, I'm creating a localcluster.
client = Client(threads_per_worker=1, n_workers=8)
client

class Node:
    def __init__(self, level):
        self.level = level
        self.val = None
        self.childs = None   # This was missing

def merge(node, childs):
    values = [child.val for child in childs]
    if all(values) and sum(values)<0.1:
        node.val = np.mean(values)
    else:
        node.childs = childs
    return node        

def compute_val():
    time.sleep(0.1)            # Is this required.
    return np.random.rand(1)

def build(node):
    print(node.level)
    if (np.random.rand(1) < 0.1 and node.level>1) or node.level>5:
        node.val = compute_val()
    else:
        with worker_client() as client:
            child_futures = [client.submit(build, Node(level=node.level+1)) for _ in range(2)]
            childs = client.gather(child_futures)
        node = merge(node, childs)
    return node

tree_future = client.submit(build, Node(level=0))
tree = tree_future.result()

推荐阅读