How would you write a less computationally expensive equivalent of numpy.where(np.ones(shape))?

Problem description

I want to get the indices of every element of an array with a given shape. I found a simple way to do it:

import numpy as np
shape = (3,3)
elements = np.where(np.ones(shape))

The result is:

>>> elements
(array([0, 0, 0, 1, 1, 1, 2, 2, 2]), array([0, 1, 2, 0, 1, 2, 0, 1, 2]))

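This is the expected behavior. However, it doesn't seem to be the most efficient way to compute it: if shape is large, np.where can be very slow. I'm looking for a more computationally efficient solution. Any ideas?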
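One alternative, assuming the caller wants exactly the tuple format that np.where returns, is to number the elements with a flat arange and convert each flat position back into per-axis coordinates with np.unravel_index. This avoids allocating the ones array and scanning it for nonzero entries. A minimal sketch; the function name indices_unravel is hypothetical and it has not been benchmarked against the methods below:

```python
import numpy as np


def indices_unravel(shape):
    """Return the same tuple of index arrays as np.where(np.ones(shape)).

    Hypothetical helper: enumerate every element in C (row-major) order,
    then map each flat position to its per-axis coordinates.
    """
    flat = np.arange(int(np.prod(shape)))
    return np.unravel_index(flat, shape)


rows, cols = indices_unravel((3, 3))
print(rows)
print(cols)
```

Because np.unravel_index accepts a shape of any length, the same call also covers 3-D and higher arrays.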

Tags: python performance numpy

Solution


Based on the comments I received, I implemented three ways of getting the same result and tested their performance.

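import timeit
import numpy as np

def with_where(a):
    # Indices of the nonzero entries of an all-ones array, i.e. of every element
    shape = a.shape
    return np.where(np.ones(shape))

def with_mgrid(a):
    # Dense grid of shape (2, rows, cols), flattened to (2, rows*cols); 2-D only
    shape = a.shape
    grid_shape = (len(shape), np.prod(shape))
    return np.mgrid[0:shape[0], 0:shape[1]].reshape(grid_shape)

def with_repeat(a):
    # Row index repeated once per column, column index tiled once per row; 2-D only
    shape = a.shape
    return np.repeat(np.arange(shape[0]), shape[1]), np.tile(np.arange(shape[1]), shape[0])

a1 = np.ones((1, 1))
a10 = np.ones((10, 10))
a100 = np.ones((100, 100))
a1000 = np.ones((1000, 1000))
a10000 = np.ones((10000, 10000))  # about 800 MB of float64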
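with_mgrid and with_repeat only handle 2-D arrays. A possible N-dimensional variant of the with_mgrid idea uses np.indices, which stacks one coordinate grid per axis into an array of shape (ndim, *shape); flattening the trailing axes gives one row of coordinates per axis in the same C order as np.where. This is a sketch with a hypothetical function name, plus a quick agreement check against np.where; it was not part of the timings below:

```python
import numpy as np


def with_indices(a):
    """Hypothetical N-D variant of with_mgrid: one row of coordinates per axis."""
    # np.indices(a.shape) has shape (a.ndim, *a.shape); flatten the grid axes.
    return np.indices(a.shape).reshape(a.ndim, -1)


for shape in [(3, 3), (2, 3, 4)]:
    a = np.ones(shape)
    grid = with_indices(a)
    expected = np.where(a)
    # Compare axis by axis: row i of grid should equal the i-th array from np.where.
    assert all(np.array_equal(g, e) for g, e in zip(grid, expected))

print("with_indices matches np.where")
```

Unlike np.where, this returns a single 2-D array rather than a tuple; tuple(with_indices(a)) converts it if the tuple form is needed for fancy indexing.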

I then ran %timeit in IPython:

%timeit with_where(a1)
%timeit with_where(a10)
%timeit with_where(a100)
%timeit with_where(a1000)
%timeit with_where(a10000)

11.1 µs ± 163 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.7 µs ± 39.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
146 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
16.2 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.49 s ± 58.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit with_mgrid(a1)
%timeit with_mgrid(a10)
%timeit with_mgrid(a100)
%timeit with_mgrid(a1000)
%timeit with_mgrid(a10000)

50.2 µs ± 2.32 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
45.9 µs ± 989 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
75.1 µs ± 1.71 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
6.17 ms ± 54.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.1 s ± 40.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit with_repeat(a1)
%timeit with_repeat(a10)
%timeit with_repeat(a100)
%timeit with_repeat(a1000)
%timeit with_repeat(a10000)

23.3 µs ± 931 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
31 µs ± 739 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
66 µs ± 1.15 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
4.41 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.05 s ± 22.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
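The script imports timeit but the measurements above come from IPython's %timeit. Outside IPython, a roughly comparable measurement can be sketched with timeit.repeat. The sketch redefines two of the functions so it runs on its own; the repeat and number counts are arbitrary choices, and the printed numbers will depend on the machine and NumPy version:

```python
import timeit

import numpy as np


def with_where(a):
    # Indices of every element, via a full scan of an all-ones array.
    return np.where(np.ones(a.shape))


def with_repeat(a):
    # Row index repeated per column, column index tiled per row (2-D only).
    rows, cols = a.shape
    return np.repeat(np.arange(rows), cols), np.tile(np.arange(cols), rows)


a = np.ones((1000, 1000))
for func in (with_where, with_repeat):
    # 5 repeats of 10 calls each; the minimum is the least noisy estimate.
    times = timeit.repeat(lambda: func(a), repeat=5, number=10)
    print(f"{func.__name__}: {min(times) / 10 * 1e3:.2f} ms per call")
```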

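So for large arrays, the np.where approach takes roughly 2 to 4 times as long as the fastest method (with_repeat): about 3.7x at 1000×1000 and about 2.4x at 10000×10000. That's not as bad as I had expected.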
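All three methods materialize index arrays with prod(shape) entries per axis, which for a 10000×10000 shape is itself a large allocation. If the coordinates are only needed one at a time (for example inside a Python loop), np.ndindex yields them lazily in C order instead. It is far slower per element than the vectorized versions, so this is a sketch for the memory-bound case only:

```python
from itertools import islice

import numpy as np

shape = (3, 3)
# np.ndindex is a lazy iterator over index tuples in C (row-major) order;
# no index arrays of length prod(shape) are allocated up front.
first_four = list(islice(np.ndindex(shape), 4))
print(first_four)

# Consuming the whole iterator visits every element exactly once.
total = sum(1 for _ in np.ndindex(shape))
print(total)
```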

