python - 在 numpy 中使用布尔索引添加值的 Numba 同义词
问题描述
我正在尝试创建更高效的代码,但我坚持实施以下 Numba 版本:
import numpy as np
a = np.array([[0, 0, 0, 0],
[0, 0, 0, 0]])
bool_idx = np.array([True, False, False, True])
a[0, bool_idx] += 3
a
array([[3, 0, 0, 3],
[0, 0, 0, 0]])
不幸的是,将此代码移植到带有 numba 的函数时出现错误:
@njit
def add_to_arr(a, idx, arr_bool, add):
arr[idx, arr_bool] += 3
return arr
add_to_arr(a=a, idx=0, arr_bool=bool_idx, add=3)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(int32, 2d, C), (int64, array(bool, 1d, C)))
解决方案
在这种情况下,Numba 似乎只允许在数组的第一维上进行高级索引。我们可以重写函数(也纠正一个轻微的错字)来适应这一点,只需使用转置和反转索引:
@njit
def add_to_arr(a, idx, arr_bool, add):
a.T[arr_bool, idx] += 3
return a
add_to_arr(a, 0, bool_idx, 3)
这对我有用,给出:
array([[3, 0, 0, 3],
[0, 0, 0, 0]])
文档说高级索引只允许在一个维度中,但没有指定这需要是第一个维度,所以这可能是一个错误。
推荐阅读
- python-3.x - AttributeError:“LogisticRegressionTrainingSummary”对象没有属性“areaUnderROC”
- c# - 使用 DB 在应用程序中进行更改时如何更新原始 SQLite DB 文件
- java - 更改包名而不更改文件路径
- sql - 如何在 presto SQL 中按月分组
- json - 使用 JQ 命令将文本文件转换为 json
- sql - SQL 查询以选择列中值的特定部分
- spring-boot - 如何将单独的引导服务器设置为绑定的 DLT
- java - 如果使用 Sping 出现异常或无效输入,如何重定向 URL?
- python - 从字典中删除特定值
- javascript - 检查页面中是否存在弹出窗口?以及如何在 IBM RFT 中处理它