python - 在 numpy 中编写自定义归约函数
问题描述
我想要做什么:假设我有一维数组形式的数据。在拟合该数据后,(scipy.optimize.curve_fit),这将减少到一个 skaler/0D 数组。到目前为止,一切都很好。那是容易的部分。
问题是,数据实际上不是一维的,而是 (n+1)D。所以我将不得不在除一个之外的所有轴上迭代整个数组,取一个 1D 切片,拟合该切片并将其写入一个具有 n 维的新数组。为简单起见,我使用 sum 函数而不是在此示例代码中拟合。
def iter_columns(array: np.ndarray, axis=-1) -> np.ndarray:
"""
Reduce nd-data to (n-1)d data. By performing the operation on one axis.
:param array: Input data
:param axis: Axis to perform reduction over
:return: Array of reduced data
"""
reduced_shape = list(array.shape)
reduced_shape.pop(axis)
print(reduced_shape)
a = np.empty(tuple(reduced_shape))
print(a)
print(array)
with np.nditer(a, flags=['multi_index'], op_flags=[['writeonly']]) as it:
for a_i in it:
# modify multi index to slice over dimension of axis, append if axis
mod = list(it.multi_index)
mod.insert(axis if axis>=0 else len(mod),slice(None))
print(mod)
a_i[...] = b[tuple(mod)].sum()
return a
b = np.arange(10).reshape(5,2)
print(iter_columns(b, axis=-1))
虽然这看起来像它应该做的那样,但它看起来并不优雅。我尝试以其他方式使用 np.nditer,但我不明白如何告诉 nditer 加载块而不仅仅是单个数组条目。我也知道有一个 ufunc.reduce 函数正是为此,但我找不到关于如何构造它可以使用的函数的文档。
解决方案
对于一般的 python 函数,没有一种快速编译的方法来进行这种减少。无论迭代机制如何,您最终都必须func
为每组nD
值调用一次。
因为np.sum
您只需指定axis
. 这本质上是一个np.add.reduce
.
np.apply_along_axis
工作起来很像你的nditer
,除了它将切片维度移动到最后,使“插入”更容易。它用于ndindex
生成索引元组 - 但它也使用nditer
. 它的文档是错误的;它并不快。
一些比较时间:
In [223]: timeit iter_columns(b, axis=-1)
55.2 µs ± 1.48 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [224]: timeit iter_columns(b, axis=-1)
54.6 µs ± 63.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [225]: timeit np.array([i.sum() for i in b])
40 µs ± 1.63 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [226]: timeit b.sum(-1)
6.86 µs ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
您的使用nditer
是正确的(通常人们会遇到问题),但正如您所见,它并没有提供任何速度优势。对于更高的维度,编写多个循环同样快。
处理更高维度的另一种方法是sum
最后移动轴,将其余部分重新整形为 1d,然后做一个简单的循环,然后重新整形。
另一种选择是编译你的函数和循环。 numba
是一个强大的工具。其中一个nditer
文档页面显示了如何nditer
在cython
.
推荐阅读
- java - 与广播接收器检查连接
- regex - 模式匹配后结果数组的 URL 路径的剩余部分
- ruby-on-rails - 有没有更简洁的方法从 Rails ActiveRecord 方法返回数组结果?
- android - SharedPreferences 在另一个类中返回空值
- typescript - 如何捕获静态成员异常?
- salesforce - INSUFFICIENT_ACCESS_OR_READONLY,对对象 ID 的访问权限不足:[]
- sql-server - SQL Server Management Studio 从 Oracle BadImageFormatException 导入数据
- javascript - 迭代对象
- javascript - Highcharts:使用相同数据更新多个图表上的系列
- rest - Angular 6 到 Spring 引导休息服务 CORS 问题