首页 > 解决方案 > 有没有更快的方法来屏蔽数组?

问题描述

我有一个 numpy 数组,我需要屏蔽它。我的功能如下所示:

def mask_arr(arr, min, max):
    for i in range(arr.size-1):
        if arr[i] < min:
            arr[i] = 0
        elif arr[i] > max:
            arr[i] = 1
        else:
            arr[i] = 10

问题是阵列很大,掩盖它需要很长时间。我怎样才能获得相同的结果但更快?

标签: pythonperformancenumpy

解决方案


您可以使用嵌套np.where,如下所示:

import numpy as np
q = np.random.rand(4,4)
# array([[0.86305369, 0.88477713, 0.58776518, 0.69122533],
#   [0.52591559, 0.33155238, 0.50139987, 0.66812239],
#   [0.83240284, 0.70147098, 0.17118681, 0.59652636],
#   [0.82031661, 0.32032657, 0.55088698, 0.28931661]])
np.where(q > 0.8, 1, np.where(q < 0.3, 0, 10))
# array([[ 1,  1, 10, 10],
#   [10, 10, 10, 10],
#   [ 1, 10,  0, 10],
#   [ 1, 10, 10,  0]])

编辑:

根据您的问题,如果您想在数组元素不大于maxVal或小于的情况下更改值,那么minVal您可以执行或您想要的任何其他逻辑:

import numpy as np
q = q = np.random.rand(4,4)
minVal = 0.3
maxVal = 0.9
qq = np.where(q > 0.8, 1, np.where(q < 0.3, 0, 2 * q))

在哪里q

[[0.63604995 0.18637738 0.90680287 0.64617278]
 [0.97435344 0.04670638 0.3510053  0.71613776]
 [0.17973416 0.50296747 0.35085383 0.853201  ]
 [0.27820978 0.69438172 0.96186074 0.96625938]]

并且qq是:

[[1.27209991 0.         1.         1.29234556]
 [1.         0.         0.7020106  1.43227553]
 [0.         1.00593493 0.70170767 1.        ]
 [0.         1.38876345 1.         1.        ]]

推荐阅读