A faster alternative to np.where for sorted arrays

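Question

Given a large array a that is sorted along each row, is there a faster alternative to NumPy's np.where for finding the indices where min_v <= a <= max_v? I would expect that exploiting the sorted nature of the array could speed things up.

Here is an example setup that uses np.where to find the matching indices in a large array.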

import numpy as np

# Initialise an example of an array in which to search
r, c = int(1e2), int(1e6)
a = np.arange(r*c).reshape(r, c)

# Set up search limits
min_v = (r*c/2)-10
max_v = (r*c/2)+10

# Find indices of occurrences
idx = np.where(((a >= min_v) & (a <= max_v)))


Solution


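You can use np.searchsorted, which runs a binary search instead of comparing every element. The version below searches the flattened array, so it relies on a.ravel() being sorted as a whole. That holds for the np.arange example, where each row continues from the end of the previous one.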

import numpy as np

r, c = 10, 100
a = np.arange(r*c).reshape(r, c)

min_v = ((r * c) // 2) - 10
max_v = ((r * c) // 2) + 10

# Old method
idx = np.where(((a >= min_v) & (a <= max_v)))

# With searchsorted
i1 = np.searchsorted(a.ravel(), min_v, 'left')
i2 = np.searchsorted(a.ravel(), max_v, 'right')
idx2 = np.unravel_index(np.arange(i1, i2), a.shape)
print((idx[0] == idx2[0]).all() and (idx[1] == idx2[1]).all())
# True
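
If each row is sorted but the flattened array is not (the general case described in the question), searching a.ravel() is no longer valid. A minimal sketch of a per-row variant follows. The helper name sorted_rows_where is hypothetical and not part of NumPy, and the sketch assumes a is 2-D with every row sorted in ascending order.

```python
import numpy as np

def sorted_rows_where(a, min_v, max_v):
    """Indices where min_v <= a <= max_v, for a 2-D array with ascending rows.

    Hypothetical helper: returns (rows, cols) in the same order as np.where.
    """
    # One binary search per row for each bound.
    starts = np.array([np.searchsorted(row, min_v, side="left") for row in a],
                      dtype=np.intp)
    stops = np.array([np.searchsorted(row, max_v, side="right") for row in a],
                     dtype=np.intp)
    counts = stops - starts  # number of matches in each row

    # Row index repeated once per match in that row.
    rows = np.repeat(np.arange(a.shape[0]), counts)

    # Position of each match within its row's run: 0, 1, ..., count - 1.
    run_begin = np.cumsum(counts) - counts
    offsets = np.arange(counts.sum()) - np.repeat(run_begin, counts)
    cols = np.repeat(starts, counts) + offsets
    return rows, cols

# Rows are sorted individually, but the flattened array is not sorted.
rng = np.random.default_rng(0)
a = np.sort(rng.integers(0, 1000, size=(4, 200)), axis=1)

rows, cols = sorted_rows_where(a, 490, 510)
expected = np.where((a >= 490) & (a <= 510))
print(np.array_equal(rows, expected[0]) and np.array_equal(cols, expected[1]))
```

The Python-level loop runs once per row, so this suits arrays with few rows and many columns, as in the question's example. For arrays with a very large number of rows, the loop itself may dominate the runtime.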
