首页 > 解决方案 > 如何使用 numpy 找到每行的最小元素的 n 个索引?

问题描述

例如:

n = 2
p1 = np.asarray([[20, 30, 10],
                 [10, 20, 30],
                 [30, 20, 10]])

结果,我想要:

[ [0, 0, 2],
  [1, 0, 1],
  [2, 1, 2] ]
            

每行的第一个数字就是 p1 中的行号。剩余的 n 个数字是该行的最小元素的索引。所以:

[0, 0, 2]
 # 0 is the index of the first row in p1.
 # (0, 2 are the indices of minimum elements of the row)


[1, 0, 1]
# 1 is the index of the second row in p1.
# (0, 1 are the indices of minimum elements of the row)

[2, 1, 2]
# 2 is the index of the third row in p1.
# (1, 2 are the indices of minimum elements of the row)

非常感谢!!!

标签: pythonnumpy

解决方案


用于np.argpartition查找前两个最小值:

import numpy as np

n = 2
p1 = np.asarray([[20, 30, 10],
                 [10, 20, 30],
                 [30, 20, 10]])

pos = np.argpartition(p1, axis=1, kth=2)

res = np.hstack([np.arange(3)[:, None], np.sort(pos[:, :2])])
print(res)

输出

[[0 0 2]
 [1 0 1]
 [2 1 2]]

一旦找到np.hstack用于连接行索引的最小值并与之相连。


推荐阅读