首页 > 解决方案 > Python/Numpy:获取二维矩阵中的前 k 个最大值作为掩码

问题描述

假设我有一个像这样的 3x3 矩阵:

array([[8, 6, 3],
       [6, 7, 2],
       [0, 8, 9]])

现在我想获取矩阵中前 k 个最大值,并从中创建一个掩码。如果数字在前 k 中最大,则值为 1,否则为 0。令k=2. 在上面的例子中,有 19和 2 8,我们需要把它们都取走,所以返回的掩码是这样的:

array([[1, 0, 0],
       [0, 0, 0],
       [0, 1, 1]])

我已经阅读了这个那个答案,我可以使用索引作为掩码。但是,我想知道是否有更好的解决方案?

标签: pythonarraysnumpy

解决方案


这个怎么样?

def is_topk(a, k=1):
    _, rix = np.unique(-a, return_inverse=True)
    return np.where(rix < k, 1, 0).reshape(a.shape)

您的阵列上的示例:

>>> is_topk(a, 1)
array([[0, 0, 0],
       [0, 0, 0],
       [0, 0, 1]])

>>> is_topk(a, 2)
array([[1, 0, 0],
       [0, 0, 0],
       [0, 1, 1]])

推荐阅读