python - njit numba 函数的高级索引替代方案
问题描述
给定以下最小可重现示例:
import numpy as np
from numba import jit
# variable number of dimensions
n_t = 8
# q is just a partition of n
q_ddl = 2
n_ddl = 3
np.random.seed(42)
df = np.random.rand(q_ddl*n_t,q_ddl*n_t)
# index array
# ddl_nl is a set of np.arange(n_ddl), ex: [0,1] ; [0,2] or even [0] ...
ddl_nl = np.array([0,1])
ij = np.asarray(np.meshgrid(ddl_nl,ddl_nl,indexing='ij'))
@jit(nopython=True)
def foo(df,ij):
out = np.zeros((n_t,n_ddl,n_ddl))
for i in range(0,n_t):
d_i = np.zeros((n_ddl,n_ddl))
# (q_ddl,q_ddl) non zero values into (n_ddl,n_ddl) shape
d_i[ij[0], ij[1]] = df[i::n_t,i::n_t]
# to check possible solutions
out[i,...] = d_i
return out
out_foo = foo(df,ij)
该功能在禁用foo
时运行良好,@jit(nopython=True)
但在启用时抛出以下错误:
TypeError: unsupported array index type array(int64, 2d, C) in UniTuple(array(int64, 2d, C) x 2)
这发生在广播操作期间d_i[ij[0], ij[1]] = df[i::n_t,i::n_t]
。然后,我确实尝试ij
使用类似的东西来展平二维索引数组,d_i[ij[0].ravel(), ij[1].ravel()] = df[i::n_t,i::n_t].ravel()
这给了我相同的输出,但现在又出现了另一个错误:
NotImplementedError: only one advanced index supported
所以我终于尝试通过使用经典的 2 嵌套for
循环结构来避免这种情况:
tmp = df[i::n_t,i::n_t]
for k,r in enumerate(ddl_nl):
for l,c in enumerate(ddl_nl):
d_i[r,c] = tmp[k,l]
它与启用的装饰器一起工作并给出预期的结果。
但是我不能停止思考我在这里缺少的这个 numpy 2d-array 广播操作是否有任何兼容 numba 的替代方案?任何帮助将不胜感激。
解决方案
避免花哨的索引
还要避免使用全局变量(它们在编译时是硬编码的)并保持你的代码尽可能简单(简单意味着只有一个露水循环,if/else,...)。如果ddl_nl
数组真的只使用 np.arange 构造,则根本不需要这个数组。
例子
import numpy as np
from numba import jit
@jit(nopython=True)
def foo_nb(df,n_ddl,n_t,ddl_nl):
out = np.zeros((n_t,n_ddl,n_ddl))
for i in range(0,n_t):
for ii in range(ddl_nl.shape[0]):
ind_1=ddl_nl[ii]
for jj in range(ddl_nl.shape[0]):
ind_2=ddl_nl[jj]
out[i,ind_1,ind_2] = df[i+ii*n_t,i+jj*n_t]
return out
计时
#Testing and compilation
A=foo(df,ij)
B=foo_nb(df,n_ddl,n_t,ddl_nl)
print(np.allclose(A,B))
#True
%timeit foo(df,ij)
#16.8 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit foo_nb(df,n_ddl,n_t,ddl_nl)
#674 ns ± 2.56 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
推荐阅读
- python - 以编程方式从 Swagger 定义生成客户端
- mysql - 如果表 2 中的日期值在表 1 中找到的日期范围内,则从表 2 中获取日期值或不可用
- c# - 使用正则表达式根据出现的不同字符拆分字符串
- c++ - 写入 Excel 文件的错误 - C++
- ruby-on-rails - 如何在 Rails 视图中显示中央标准时间 (CST)?
- kubernetes - 为什么 operator-courier verify 抱怨版本不匹配?
- python - Python:使用多个变量最小化多个函数
- python - 使用 rpy2 加载 R 包时 R 内核崩溃
- php - 检查 Woocommerce 中的产品类别页面是否为空
- apache-spark - 具有多列的 DataFrame 过滤器在 Spark 2.2(scala)中无法使用 && 运算符