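python - How would you write a computationally cheaper equivalent of numpy.where(np.ones(shape))?
Problem description
I want to get the list of element indices for an array of a given shape. I found a simple way to do this: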
import numpy as np
shape = (3,3)
elements = np.where(np.ones(shape))
The result is:
>>> elements
(array([0, 0, 0, 1, 1, 1, 2, 2, 2]), array([0, 1, 2, 0, 1, 2, 0, 1, 2]))
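This is the expected behavior. However, it does not seem to be the most computationally efficient approach: np.where(np.ones(shape)) first allocates a full array of ones and then scans it for nonzero entries, so if shape is large, np.where can be very slow. I'm looking for a more efficient solution. Any ideas?
Solution
Based on the comments I received, I implemented 3 methods that produce the same index values and tested their performance.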
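One possible direction, sketched here under the assumption that only the index arrays are needed: NumPy can build the same coordinates directly, without allocating and scanning an array of ones. np.unravel_index and np.indices are existing NumPy functions; the two wrapper names below are hypothetical, and whether they are actually faster on a given machine would still need to be measured.

```python
import math
import numpy as np

def indices_via_unravel(shape):
    # Hypothetical helper: convert flat positions 0..N-1 into per-axis
    # coordinates in C (row-major) order, the same order np.where/np.nonzero uses.
    return np.unravel_index(np.arange(math.prod(shape)), shape)

def indices_via_indices(shape):
    # Hypothetical helper: np.indices builds a (ndim, *shape) integer grid;
    # flattening each axis yields one coordinate array per dimension.
    grid = np.indices(shape)
    return tuple(grid.reshape(len(shape), -1))

shape = (3, 3)
print(indices_via_unravel(shape))
print(indices_via_indices(shape))
```

Both helpers work for any number of dimensions, unlike the 2-D-only variants below.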
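import timeit
import numpy as np

def with_where(a):
    shape = a.shape
    # With a single argument, np.where is equivalent to np.nonzero: it scans the whole array.
    return np.where(np.ones(shape))

def with_mgrid(a):
    shape = a.shape
    # 2-D only: build a (2, rows, cols) grid and flatten each axis.
    # Returns a (2, N) array rather than a tuple of arrays.
    grid_shape = (len(shape), np.prod(shape))
    return np.mgrid[0:shape[0], 0:shape[1]].reshape(grid_shape)

def with_repeat(a):
    shape = a.shape
    # 2-D only: row indices repeat each value, column indices tile the whole range.
    return np.repeat(np.arange(shape[0]), shape[1]), np.tile(np.arange(shape[1]), shape[0])

a1 = np.ones((1,1))
a10 = np.ones((10,10))
a100 = np.ones((100,100))
a1000 = np.ones((1000,1000))
a10000 = np.ones((10000,10000))  # about 800 MB of float64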
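with_mgrid and with_repeat only handle 2-D input. A minimal sketch of the same repeat/tile idea for any number of dimensions follows; the name with_repeat_nd is hypothetical, and it takes a shape tuple rather than an array.

```python
import math
import numpy as np

def with_repeat_nd(shape):
    # Hypothetical N-dimensional generalization of with_repeat.
    # In C order, the index along axis k changes every `inner` elements
    # (product of the later dimensions), and that block pattern repeats
    # `outer` times (product of the earlier dimensions).
    out = []
    for k, n in enumerate(shape):
        inner = math.prod(shape[k + 1:])
        outer = math.prod(shape[:k])
        out.append(np.tile(np.repeat(np.arange(n), inner), outer))
    return tuple(out)

print(with_repeat_nd((3, 3)))
```

For a 2-D shape this reduces to exactly the two expressions in with_repeat.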
Then I ran %timeit in IPython:
%timeit with_where(a1)
%timeit with_where(a10)
%timeit with_where(a100)
%timeit with_where(a1000)
%timeit with_where(a10000)
11.1 µs ± 163 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.7 µs ± 39.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
146 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
16.2 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.49 s ± 58.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit with_mgrid(a1)
%timeit with_mgrid(a10)
%timeit with_mgrid(a100)
%timeit with_mgrid(a1000)
%timeit with_mgrid(a10000)
50.2 µs ± 2.32 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
45.9 µs ± 989 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
75.1 µs ± 1.71 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
6.17 ms ± 54.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.1 s ± 40.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit with_repeat(a1)
%timeit with_repeat(a10)
%timeit with_repeat(a100)
%timeit with_repeat(a1000)
%timeit with_repeat(a10000)
23.3 µs ± 931 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
31 µs ± 739 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
66 µs ± 1.15 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
4.41 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.05 s ± 22.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
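So for large arrays, the np.where approach is only about 2 to 2.5 times slower than the fastest method, with_repeat (2.49 s vs. 1.05 s for the 10000×10000 array). That is not as bad as I thought.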
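The script imports timeit, but the numbers above come from IPython's %timeit magic. Outside IPython, a comparable measurement can be sketched with the standard-library timeit.repeat. The helper best_time and the shape-taking variants of with_where and with_repeat are hypothetical adaptations, and the printed times will depend on the machine and NumPy version.

```python
import timeit
import numpy as np

def with_where(shape):
    # Same computation as the original with_where, but takes a shape tuple.
    return np.where(np.ones(shape))

def with_repeat(shape):
    # Same computation as the original with_repeat, but takes a shape tuple (2-D only).
    return np.repeat(np.arange(shape[0]), shape[1]), np.tile(np.arange(shape[1]), shape[0])

def best_time(func, shape, number=10, repeat=5):
    # Run `number` calls per trial, `repeat` trials; report the best per-call
    # time in seconds (the minimum is usually less noisy than the mean).
    times = timeit.repeat(lambda: func(shape), number=number, repeat=repeat)
    return min(times) / number

if __name__ == "__main__":
    for shape in [(10, 10), (100, 100), (1000, 1000)]:
        for func in (with_where, with_repeat):
            print(f"{func.__name__:12s} {shape}: {best_time(func, shape) * 1e3:.3f} ms")
```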