首页 > 解决方案 > 如何获取python多维np数组中所有最大值的索引

问题描述

我想从 ndarray 中获取每一行的最大值的所有索引。例如,我有

arr = np.array([[1, 3, 3], [1, 5, 4]])

我想从第一行获取所有 3 的索引,从第二行获取所有 5 的索引。

np.where(((arr == arr[0].max()) | (arr == arr[1].max())))

它返回

(array([0, 0, 1], dtype=int64), array([1, 2, 1], dtype=int64))

我想要类似的东西,但对于任何数量的行都更通用。因为np.where(arr == arr.argmax())不像我想要的那样工作。它只返回它在每一行中遇到的第一个最大值的索引。

标签: pythonarraysnumpy

解决方案


@Paul 在评论中的回答可能是你能找到的最好的答案。写给读者看。arr.max(1)在每一行中找到最大值,并在每一行arr==arr.max(1,keepdims=True)中找到等于该行中对应最大值的所有元素。最后nonzero返回这些元素的索引:

(arr==arr.max(axis=1,keepdims=True)).nonzero()

OP示例的输出:

(array([0, 0, 1]), array([1, 2, 1]))

推荐阅读