首页 > 解决方案 > 查找numpy数组中列最大值的索引但删除前一个最大值

问题描述

我有一个包含 N 行和 M 列的数组。

我想遍历所有列,找到包含该列最大值的行的索引。但是,每一行只能选择一次。

例如,让我们考虑一个矩阵

1 1
2 2

输出应该是[1, 0]. 因为第 1 行(值为 2)是第 0 列的最大值,那么我们移动到第 2 列,第 1 行不考虑,所以第 0 行将是最高的单元格。

事实上,使用嵌套的 for 循环可以轻松解决问题,例如:

removed_rows = []
for i in range (nb_columns):
  index_max = 0
  value_max = A[0,i]
  for j in range (nb_rows):
    if j in removed_rows:
      continue
    else:
      if value_max < A[j,i]:
        index_max = j
        value_max = A[j,i]
  removed_rows.append (index_max)

但是,对于一个巨大的矩阵来说,它似乎很慢。有什么方法可以让我们更快地做到这一点(使用 numpy?)?

非常感谢

标签: numpy

解决方案


这可能不是很快,因为它仍然循环遍历列,我认为由于约束是不可避免的,但应该比您的解决方案更快,因为它找到了最大值的索引argmax

out = []
mm = A.min() - 1

for j in range(A.shape[1]):
    idx = np.argmax(A[:,j])

    # replace the entire row with mm
    # so next `argmax` will ignore this row
    A[idx] = mm
    out.append(idx)
    

以上在 100 x 100 阵列上大约需要 640 us,在 1k x 1k 阵列上需要 18ms。您的代码拒绝在我的系统上的合理时间内在 1k x 1k 数组上运行。


推荐阅读