首页 > 解决方案 > 在 dask.Array 任务图中嵌入计算前/计算后操作

问题描述

我有兴趣创建一个dask.array.Array在之前/之后打开和关闭资源的compute()。但是,我不想对最终用户将如何调用做任何假设,compute 并且我想避免创建自定义 daskArray子类或代理对象,所以我试图将操作嵌入到__dask_graph__底层数组中.

(旁白:请暂时忽略有关在 dask 中使用有状态对象的警告,我知道风险,这个问题只是关于任务图操作)。

考虑下面的类,它模拟了一个必须处于打开状态才能读取块的文件读取器,否则它会出现段错误。

import dask.array as da
import numpy as np

class FileReader:
    _open = True

    def open(self):
        self._open = True

    def close(self):
        self._open = False

    def to_dask(self) -> da.Array:
        return da.map_blocks(
            self._dask_block,
            chunks=((1,) * 4, 4, 4),
            dtype=float,
        )

    def _dask_block(self):
        if not self._open:
            raise RuntimeError("Segfault!")
        return np.random.rand(1, 4, 4)

如果文件保持打开状态,一切都很好,但如果关闭文件,则返回的 dask 数组to_dask将失败:

>>> t = FileReader()
>>> darr = t.to_dask()
>>> t.close()
>>> darr.compute()  # RuntimeError: Segfault!

当前的任务图如下所示:

>>> list(darr.dask)
[
    ('_dask_block-aa4daac0835bafe001693f9ac085683a', 0, 0, 0),
    ('_dask_block-aa4daac0835bafe001693f9ac085683a', 1, 0, 0),
    ('_dask_block-aa4daac0835bafe001693f9ac085683a', 2, 0, 0),
    ('_dask_block-aa4daac0835bafe001693f9ac085683a', 3, 0, 0)
]

本质上,我想在开头添加新任务,该_dask_block层依赖于该层,并在末尾添加一个任务,该任务取决于_dask_block.

我尝试直接操纵HighLevelGraph输出da.map_blocks以手动添加这些任务,但发现它们在计算优化期间被修剪,因为darr.__dask_keys__()不包含我的密钥(而且,我想再次避免子类化或要求最终用户compute使用特殊优化标志调用)。

一种解决方案是确保_dask_block传递给 map_blocks 的函数始终打开和关闭底层资源......但是让我们假设打开/关闭过程相对较慢,有效地破坏了单机并行性的性能。所以我们只想要一个在开始时打开,并在结束时关闭。

我可以通过在我的调用中包含一个新密钥来稍微“作弊”以包含一个打开我的文件的任务,map_blocks如下所示:

    ...
    
    # new method that will run at beginning of compute()
    def _pre_compute(self):
        was_open = self._open
        if not was_open:
            self.open()
        return was_open

    def to_dask(self) -> da.Array:
        # new task key
        pre_task = 'pre_compute-' + tokenize(self._pre_compute)
        arr = da.map_blocks(
            self._dask_block,
            pre_task,  # add key here so all chunks depend on it
            chunks=((1,) * 4, 4, 4),
            dtype=float,
        )
        # add key to HighLevelGraph
        arr.dask.layers[pre_task] = {pre_task: (self._pre_compute,)}
        return da.Array(arr.dask, arr.name, arr.chunks, arr.dtype)

    # add "mock" argument to map_blocks function
    def _dask_block(self, _):
        if not self._open:
            raise RuntimeError("Segfault!")
        return np.random.rand(1, 4, 4)

到目前为止一切顺利,不再RuntimeError......但现在我已经泄露了文件句柄,因为最后没有关闭它。

那么我想要的是图表末尾的一个任务,它取决于pre_task(即是否必须为此计算打开文件)的输出,如果必须打开文件则关闭文件。

这就是我卡住的地方,因为该post-compute密钥将被优化器修剪......

有什么方法可以做到这一点,而无需创建自定义 Array 子类来覆盖 or 之类的方法__dask_postcompute____dask_keys__或者要求最终用户在没有优化的情况下调用计算?

标签: pythondask

解决方案


这是一个非常有趣的问题。我认为您在编辑任务图以包括打开和关闭共享资源的任务方面处于正确的轨道上。但是手动图形操作很繁琐,而且很难做到正确。

我认为完成你想要的最简单的方法是使用一些相对最近添加的实用程序来处理dask.graph_manipulation. 特别是,我认为我们想要bind,可用于向 Dask 集合添加隐式依赖项,并且wait_for,可用于确保集合的依赖项等待另一个不相关的集合。

我通过使用这些实用程序修改您的示例以创建各种to_dask()自动打开和关闭的示例:

import dask
import dask.array as da
import numpy as np
from dask.graph_manipulation import bind, checkpoint, wait_on


class FileReader:
    _open = False

    def open(self):
        self._open = True

    def close(self):
        self._open = False

    def to_dask(self) -> da.Array:
        # Delayed version of self.open
        @dask.delayed
        def open_resource():
            self.open()
        # Delayed version of self.close
        @dask.delayed
        def close_resource():
            self.close()
            
        opener = open_resource()
        arr = da.map_blocks(
            self._dask_block,
            chunks=((1,) * 4, 4, 4),
            dtype=float,
        )
        # Make sure the array is dependent on `opener`
        arr = bind(arr, opener)

        closer = close_resource()
        # Make sure the closer is dependent on the array being done
        closer = bind(closer, arr)
        # Make sure dependents of arr happen after `closer` is done
        arr, closer = wait_on(arr, closer)
        return arr

    def _dask_block(self):
        if not self._open:
            raise RuntimeError("Segfault!")
        return np.random.rand(1, 4, 4)

我发现在操作前后查看任务图很有趣。

之前,它是一个相对简单的分块数组:

chunked-array-with-no-checkpoints-or-dependencies

但是在操作之后,您可以看到数组块依赖于open_resource,然后这些块流入close_resource,然后流入让数组块进入更广阔的世界:

任务图可视化自开闭数组


推荐阅读