首页 > 解决方案 > 如何加快python中的numpy数组布尔掩码?

问题描述

我需要在 python 中加速以下布尔掩码:

import numpy as np

# test dataset
n=1000000
mask = np.random.choice(a=[False, True], size=(n,), p=[0.8, 0.2])
arr = np.random.rand(n)

# the code I need to speed up:
res = arr[mask]

有什么想法可以让它在 python 中更快吗?

标签: python

解决方案


我写了一个 numba 实现,它的性能几乎和 numpy 一样。

import numba as nb
import numpy as np

@nb.njit(parallel=False, fastmath=True)
def array_masking_float(arr, mask, res):
    j=0
    for i in nb.prange(mask.shape[0]):
        if mask[i] == True:
            res[j] = arr[i]
            j += 1

任何人都可以提出更好的解决方案吗?


推荐阅读