首页 > 解决方案 > 在 numba 的 jitclass 中索引多维 numpy 数组

问题描述

我正在尝试将一个小的多维数组插入到 numba jitclass 中的一个较大的数组中。小数组设置由索引列表定义的大数组的特定位置。

以下 MWE 显示了没有 numba 的问题 - 一切都按预期工作

import numpy as np

class NumbaClass(object):

    def __init__(self, n, m):
        self.A = np.zeros((n, m))

    # solution 1 using pure python
    def nonNumbaFunction1(self, idx, values):
        self.A[idx[:, None], idx] = values

    # solution 2 using pure python
    def nonNumbaFunction2(self, idx, values):
        self.A[np.ix_(idx, idx)] = values

if __name__ == "__main__":
    n = 6
    m = 8
    obj = NumbaClass(n, m)
    print(f'A =\n{obj.A}')

    idx = np.array([0, 2, 5])
    values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
    print(f'values =\n{values}')

    obj.nonNumbaFunction1(idx, values)
    print(f'A =\n{obj.A}')

    obj.nonNumbaFunction2(idx, values)
    print(f'A =\n{obj.A}')

这两个函数nonNumbaFunction1nonNumbaFunction2不能在 numba 类中工作。所以我目前的解决方案看起来像这样,在我看来这不是很好

import numpy as np

from numba import jitclass      
from numba import int64, float64
from collections import OrderedDict

specs = OrderedDict()
specs['A'] = float64[:, :]

@jitclass(specs)
class NumbaClass(object):

    def __init__(self, n, m):
        self.A = np.zeros((n, m))

    # solution for numba jitclass
    def numbaFunction(self, idx, values):
        for i in range(len(values)):
            idxi = idx[i]
            for j in range(len(values)):
                idxj = idx[j]
                self.A[idxi, idxj] = values[i, j]

if __name__ == "__main__":
    n = 6
    m = 8
    obj = NumbaClass(n, m)
    print(f'A =\n{obj.A}')

    idx = np.array([0, 2, 5])
    values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
    print(f'values =\n{values}')

    obj.numbaFunction(idx, values)
    print(f'A =\n{obj.A}')

所以我的问题是:

知道插入的数组很小(4x4 到 10x10)可能很有用,但是这个索引出现在嵌套循环中,所以它也必须快速安静!后来我也需要对三维对象进行类似的索引。

标签: pythonmultidimensional-arrayindexingjitnumba

解决方案


由于 numba 对索引支持的限制,我认为没有比自己写出 for 循环更好的方法了。为了使其跨维度通用,您可以使用generated_jit装饰器进行专业化。像这样的东西:

def set_2d(target, values, idx):
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            target[idx[i], idx[j]] = values[i, j]

def set_3d(target, values, idx):
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            for k in range(values.shape[2]):
                target[idx[i], idx[j], idx[k]] = values[i, j, l]

@numba.generated_jit
def set_nd(target, values, idx):
    if target.ndim == 2:
        return set_2d
    elif target.ndim == 3:
        return set_3d

然后,这可以在你的 jitclass 中使用

specs = OrderedDict()
specs['A'] = float64[:, :]

@jitclass(specs)
class NumbaClass(object):
    def __init__(self, n, m):
        self.A = np.zeros((n, m))
    def numbaFunction(self, idx, values):
        set_nd(self.A, values, idx)

推荐阅读