Sorting two separate xarray DataArrays based on only one of them, in a dask-friendly way

Problem description

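Suppose I have two DataArrays, A and B, both with dimensions time, x, z. I want to sort all the values of A over x and z only, so that at each individual time I have a DataArray with sorted values. At the same time, I also want to sort B, but based on the order of A.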

If I only had 1-D numpy arrays, I could get what I want by following this answer:

>>> a = numpy.array([2, 3, 1])
>>> b = numpy.array([4, 6, 7])
>>> p = a.argsort()
>>> p
array([2, 0, 1])
>>> a[p]
array([1, 2, 3])
>>> b[p]
array([7, 4, 6])
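
The same idea extends to 2-D arrays, which is closer to the DataArray case once x and z are flattened into a single axis. A minimal sketch, assuming each row plays the role of one time step (the names `a2`, `b2`, and `order` are illustrative and not part of the original question):

```python
import numpy as np

# Each row stands for one time step; columns are the flattened (x, z) values.
a2 = np.array([[2, 3, 1],
               [9, 7, 8]])
b2 = np.array([[4, 6, 7],
               [1, 2, 3]])

# Indices that sort each row of a2 independently.
order = np.argsort(a2, axis=-1)

# Apply the same per-row permutation to both arrays.
a2_sorted = np.take_along_axis(a2, order, axis=-1)
b2_sorted = np.take_along_axis(b2, order, axis=-1)
```

Here `np.take_along_axis` does per row what `a[p]` and `b[p]` do in the 1-D session.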

For DataArrays, however, the problem gets a bit more complicated. I can get something that works with the following code:

import numpy as np
import xarray as xr

def zipsort_xarray(da_a, da_b, unsorted_dim="time"):
    assert da_a.dims == da_b.dims, "Dimensions aren't the same"
    for dim in da_a.dims:
        assert np.allclose(da_a[dim], da_b[dim]), f"Coordinates of {dim} aren't the same"

    sorted_dims = [ dim for dim in da_a.dims if dim != unsorted_dim ]
    daa_aux = da_a.stack(aux_dim=sorted_dims) # stack all dims to be sorted into one

    indices = np.argsort(daa_aux, axis=-1) # get indices that sort the last (stacked) dim
    indices[unsorted_dim] = range(indices.sizes[unsorted_dim]) # turn unsorted_dim into a counter
    flat_indices = np.concatenate(indices + indices[unsorted_dim]*indices.sizes["aux_dim"]) # make indices appropriate for indexing a fully flattened version of the data array

    daa_aux2 = da_a.stack(aux_dim2=da_a.dims) # get a fully flatten version of the data array
    daa_aux2.values = daa_aux2.values[flat_indices] # apply the flattened indices to sort it

    dab_aux2 = da_b.stack(aux_dim2=da_b.dims) # get a fully flatten version of the data array
    dab_aux2.values = dab_aux2.values[flat_indices] # apply the same flattened indices to sort it

    return daa_aux2.unstack(), dab_aux2.unstack() # return unflattened (unstacked) DataArrays



tsize=2
xsize=2
zsize=2

data1 = xr.DataArray(np.random.randn(tsize, xsize, zsize), dims=("time", "x", "z"),
                     coords=dict(time=range(tsize),
                                 x=range(xsize),
                                 z=range(zsize)))
data2 = xr.DataArray(np.random.randn(tsize, xsize, zsize), dims=("time", "x", "z"),
                     coords=dict(time=range(tsize),
                                 x=range(xsize),
                                 z=range(zsize)))
sort1, sort2 = zipsort_xarray(data1.transpose("time", "z", "x"), data2.transpose("time", "z", "x"))

However, not only does this feel a bit "hacky" to me, it also doesn't work with dask.

I plan to use this on large DataArrays that will be chunked in time, so it's important that I get something that works in that case. However, when I chunk the DataArrays in time:

data1 = data1.chunk(dict(time=1))
data2 = data2.chunk(dict(time=1))
sort1, sort2 = zipsort_xarray(data1.transpose("time", "z", "x"), data2.transpose("time", "z", "x"))

I get this error:

NotImplementedError: 'argsort' is not yet a valid method on dask arrays

Is there a way to make this work with chunked DataArrays?

Tags: python, arrays, numpy, dask, python-xarray

Solution


I think I have something that seems to work fully in parallel. It only works if the time dimension is chunked with chunks of size 1:

import xarray as xr
import numpy as np


def zipsort3(da_a, da_b, unsorted_dim="time"):
    """
    Only works if both `da_a` and `da_b` are chunked in `unsorted_dim`
    with size 1 chunks
    """
    from dask.array import map_blocks
    assert da_a.dims == da_b.dims, "Dimensions aren't the same"
    for dim in da_a.dims:
        assert np.allclose(da_a[dim], da_b[dim]), f"Coordinates of {dim} aren't the same"

    sorted_dims = [ dim for dim in da_a.dims if dim != unsorted_dim ]
    daa_aux = da_a.stack(aux_dim=sorted_dims).transpose(unsorted_dim, "aux_dim") # stack all dims to be sorted into one
    dab_aux = da_b.stack(aux_dim=sorted_dims).transpose(unsorted_dim, "aux_dim") # stack all dims to be sorted into one

    indices = map_blocks(np.argsort, daa_aux.data, axis=-1, dtype=np.int64)

    def reorder(A, ind):
        # each block has shape (1, n): reorder its single row with that block's indices
        return A[0, ind]
    daa_aux.data = map_blocks(reorder, daa_aux.data, indices, dtype=daa_aux.dtype)
    dab_aux.data = map_blocks(reorder, dab_aux.data, indices, dtype=dab_aux.dtype)
    return daa_aux.unstack(), dab_aux.unstack()


tsize=2
xsize=2
zsize=2

data1 = xr.DataArray(np.random.randn(tsize, xsize, zsize), dims=("time", "x", "z"),
                     coords=dict(time=range(tsize),
                                 x=range(xsize),
                                 z=range(zsize)))
data2 = xr.DataArray(np.random.randn(tsize, xsize, zsize), dims=("time", "x", "z"),
                     coords=dict(time=range(tsize),
                                 x=range(xsize),
                                 z=range(zsize)))

data1 = data1.chunk(dict(time=1))
data2 = data2.chunk(dict(time=1))

sorted1, sorted2 = zipsort3(data1.transpose("time", "z", "x"), data2.transpose("time", "z", "x"))
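
The size-1 chunk restriction comes from `reorder`, which indexes only row 0 of each block. One possible variation, sketched here in plain NumPy and assuming each block has shape (n_time, n_aux) and holds the whole stacked axis, is to use `np.take_along_axis` so that blocks with several time steps are reordered row by row. The helper name `reorder_rows` is hypothetical and not part of the original answer:

```python
import numpy as np

def reorder_rows(block, ind):
    """Reorder every row of `block` with the matching row of `ind`."""
    return np.take_along_axis(block, ind, axis=-1)

# Stand-ins for one dask block of each array, holding two time steps.
rng = np.random.default_rng(0)
block_a = rng.standard_normal((2, 4))
block_b = rng.standard_normal((2, 4))

ind = np.argsort(block_a, axis=-1)
sorted_a = reorder_rows(block_a, ind)
sorted_b = reorder_rows(block_b, ind)

# For comparison: A[0, ind] draws every output row from row 0 of the block,
# which is only correct when the block has a single row.
row0_only = block_a[0, ind]
```

Passing a function like `reorder_rows` to `map_blocks` in place of `reorder` should relax the chunk-size requirement along time, though that has not been verified against dask here. The stacked axis would still need to sit in a single chunk, since `np.argsort` runs on each block separately.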
