首页 > 解决方案 > 在 2d numpy 数组的每行中查找 N 个最高值

问题描述

我有一个 2d numpy 数组a

a = np.array(range(0,25)).reshape(5,5)
---
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]]

我想找到每行的最大 N 值并将它们替换为 100。我以缓慢的方式执行此操作:

N = 2
idx = a.argsort()
for i in range(a.shape[0]):
    a[i,idx[i][::-1][0:N]] = 100
print(a)
---
[[  0   1   2 100 100]
 [  5   6   7 100 100]
 [ 10  11  12 100 100]
 [ 15  16  17 100 100]
 [ 20  21  22 100 100]]

实际上我的矩阵的形状是 6000*6000。如何以更好的方式做到这一点?喜欢申请?

标签: pythonnumpy

解决方案


你可以argpartition在这里使用:

N=2
ix = a.argpartition(-N)[:,-N:]
a[np.arange(a.shape[0])[:,None], ix] = 100

print(a)
array([[  0,   1,   2, 100, 100],
       [  5,   6,   7, 100, 100],
       [ 10,  11,  12, 100, 100],
       [ 15,  16,  17, 100, 100],
       [ 20,  21,  22, 100, 100]])

推荐阅读