首页 > 解决方案 > 在 numpy 中使用布尔索引添加值的 Numba 同义词

问题描述

我正在尝试创建更高效​​的代码,但我坚持实施以下 Numba 版本:

import numpy as np

a = np.array([[0, 0, 0, 0],
              [0, 0, 0, 0]])

bool_idx = np.array([True, False, False, True])

a[0, bool_idx] += 3
a

array([[3, 0, 0, 3],
       [0, 0, 0, 0]])

不幸的是,将此代码移植到带有 numba 的函数时出现错误:

@njit
def add_to_arr(a, idx, arr_bool, add):
    arr[idx, arr_bool] += 3
    return arr

add_to_arr(a=a, idx=0, arr_bool=bool_idx, add=3)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(int32, 2d, C), (int64, array(bool, 1d, C)))

标签: pythonnumpynumba

解决方案


在这种情况下,Numba 似乎只允许在数组的第一维上进行高级索引。我们可以重写函数(也纠正一个轻微的错字)来适应这一点,只需使用转置和反转索引:

@njit 
def add_to_arr(a, idx, arr_bool, add): 
    a.T[arr_bool, idx] += 3 
    return a 

add_to_arr(a, 0, bool_idx, 3)     

这对我有用,给出:

array([[3, 0, 0, 3],
       [0, 0, 0, 0]])

文档说高级索引只允许在一个维度中,但没有指定这需要是第一个维度,所以这可能是一个错误。


推荐阅读