首页 > 解决方案 > 在 numpy 中编写自定义归约函数

问题描述

我想要做什么:假设我有一维数组形式的数据。在拟合该数据后,(scipy.optimize.curve_fit),这将减少到一个 skaler/0D 数组。到目前为止,一切都很好。那是容易的部分。

问题是,数据实际上不是一维的,而是 (n+1)D。所以我将不得不在除一个之外的所有轴上迭代整个数组,取一个 1D 切片,拟合该切片并将其写入一个具有 n 维的新数组。为简单起见,我使用 sum 函数而不是在此示例代码中拟合。

def iter_columns(array: np.ndarray, axis=-1) -> np.ndarray:
    """
    Reduce nd-data to (n-1)d data. By performing the operation on one axis.
    :param array: Input data
    :param axis: Axis to perform reduction over
    :return: Array of reduced data
    """
    reduced_shape = list(array.shape)
    reduced_shape.pop(axis)
    print(reduced_shape)
    a = np.empty(tuple(reduced_shape))
    print(a)
    print(array)
    with np.nditer(a, flags=['multi_index'], op_flags=[['writeonly']]) as it:
        for a_i in it:
            # modify multi index to slice over dimension of axis, append if axis
            mod = list(it.multi_index)
            mod.insert(axis if axis>=0 else len(mod),slice(None))
            print(mod)
            a_i[...] = b[tuple(mod)].sum()
    return a


b = np.arange(10).reshape(5,2)
print(iter_columns(b, axis=-1))

虽然这看起来像它应该做的那样,但它看起来并不优雅。我尝试以其他方式使用 np.nditer,但我不明白如何告诉 nditer 加载块而不仅仅是单个数组条目。我也知道有一个 ufunc.reduce 函数正是为此,但我找不到关于如何构造它可以使用的函数的文档。

标签: pythonnumpymultidimensional-arrayiterationslice

解决方案


对于一般的 python 函数,没有一种快速编译的方法来进行这种减少。无论迭代机制如何,您最终都必须func为每组nD值调用一次。

因为np.sum您只需指定axis. 这本质上是一个np.add.reduce.

np.apply_along_axis工作起来很像你的nditer,除了它将切片维度移动到最后,使“插入”更容易。它用于ndindex生成索引元组 - 但它也使用nditer. 它的文档是错误的;它并不快。

一些比较时间:

In [223]: timeit iter_columns(b, axis=-1)
55.2 µs ± 1.48 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [224]: timeit iter_columns(b, axis=-1)
54.6 µs ± 63.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [225]: timeit np.array([i.sum() for i in b])
40 µs ± 1.63 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [226]: timeit b.sum(-1)
6.86 µs ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

您的使用nditer是正确的(通常人们会遇到问题),但正如您所见,它并没有提供任何速度优势。对于更高的维度,编写多个循环同样快。

处理更高维度的另一种方法是sum最后移动轴,将其余部分重新整形为 1d,然后做一个简单的循环,然后重新整形。

另一种选择是编译你的函数和循环。 numba是一个强大的工具。其中一个nditer文档页面显示了如何nditercython.


推荐阅读