python - 在 numba 的 jitclass 中索引多维 numpy 数组
问题描述
我正在尝试将一个小的多维数组插入到 numba jitclass 中的一个较大的数组中。小数组设置由索引列表定义的大数组的特定位置。
以下 MWE 显示了没有 numba 的问题 - 一切都按预期工作
import numpy as np
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
# solution 1 using pure python
def nonNumbaFunction1(self, idx, values):
self.A[idx[:, None], idx] = values
# solution 2 using pure python
def nonNumbaFunction2(self, idx, values):
self.A[np.ix_(idx, idx)] = values
if __name__ == "__main__":
n = 6
m = 8
obj = NumbaClass(n, m)
print(f'A =\n{obj.A}')
idx = np.array([0, 2, 5])
values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
print(f'values =\n{values}')
obj.nonNumbaFunction1(idx, values)
print(f'A =\n{obj.A}')
obj.nonNumbaFunction2(idx, values)
print(f'A =\n{obj.A}')
这两个函数nonNumbaFunction1
都nonNumbaFunction2
不能在 numba 类中工作。所以我目前的解决方案看起来像这样,在我看来这不是很好
import numpy as np
from numba import jitclass
from numba import int64, float64
from collections import OrderedDict
specs = OrderedDict()
specs['A'] = float64[:, :]
@jitclass(specs)
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
# solution for numba jitclass
def numbaFunction(self, idx, values):
for i in range(len(values)):
idxi = idx[i]
for j in range(len(values)):
idxj = idx[j]
self.A[idxi, idxj] = values[i, j]
if __name__ == "__main__":
n = 6
m = 8
obj = NumbaClass(n, m)
print(f'A =\n{obj.A}')
idx = np.array([0, 2, 5])
values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
print(f'values =\n{values}')
obj.numbaFunction(idx, values)
print(f'A =\n{obj.A}')
所以我的问题是:
- 有谁知道 numba 中这种索引的解决方案,还是有另一种矢量化解决方案?
- 有更快的解决方案
nonNumbaFunction1
吗?
知道插入的数组很小(4x4 到 10x10)可能很有用,但是这个索引出现在嵌套循环中,所以它也必须快速安静!后来我也需要对三维对象进行类似的索引。
解决方案
由于 numba 对索引支持的限制,我认为没有比自己写出 for 循环更好的方法了。为了使其跨维度通用,您可以使用generated_jit
装饰器进行专业化。像这样的东西:
def set_2d(target, values, idx):
for i in range(values.shape[0]):
for j in range(values.shape[1]):
target[idx[i], idx[j]] = values[i, j]
def set_3d(target, values, idx):
for i in range(values.shape[0]):
for j in range(values.shape[1]):
for k in range(values.shape[2]):
target[idx[i], idx[j], idx[k]] = values[i, j, l]
@numba.generated_jit
def set_nd(target, values, idx):
if target.ndim == 2:
return set_2d
elif target.ndim == 3:
return set_3d
然后,这可以在你的 jitclass 中使用
specs = OrderedDict()
specs['A'] = float64[:, :]
@jitclass(specs)
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
def numbaFunction(self, idx, values):
set_nd(self.A, values, idx)
推荐阅读
- c++ - 将多个子类包含到 C++ 文件中的最佳方法是什么?
- laravel - Laravel 刀片中的未定义视图
- javascript - AG-Grid 标题单元格选择
- ios - 在 applicationdidfinishlaunching 中拍摄 ios 启动屏幕的快照
- javascript - 根据 [Slick] 滑块选择加载 div
- javascript - 获得对 CSSMediaRule 的可读访问权限
- parameters - 如何将 alpha(透明度)传递给 seaborn.jointplot?
- java - java.lang.Exception:谁在打电话?通过 PJSUA 拨打电话时
- python - 检查输入时出错:预期 lstm_29_input 的形状为 (None, None, 2) 但得到的数组的形状为 (51, 1, 10)
- dart - Dart 捕获 http 异常