首页 > 解决方案 > Numba 缓存和迭代求解

问题描述

我正在尝试使用 NumbaLSODA(参见此处)来迭代地解决许多系统。但是,似乎我无法将该函数放入 numba 缓存中,因此每次迭代都有很大的开销。我想知道是否有人可以提供帮助。

这是我的代码。首先,我将所有与 numba 相关的函数放在一个文件中

numba_func.py

NOPYTHON = True
CACHE = True
PARALLEL = True

@nb.cfunc(lsoda_sig, nopython=NOPYTHON, cache=CACHE)
def rhs(x, i, di, p):
    di[0] = ...
    di[1] = ...
    di[2] = ...

funcptr = rhs.address

@nb.njit('(float64)(float64, float64, float64, int16)', nopython=NOPYTHON, parallel=PARALLEL, cache=CACHE)
def solve(a, b, c, funcptr):

    w0 = np.array([a, b, ...], dtype=np.float64)
    p  = np.array([c], dtype=np.float64)

    x = np.linspace(0, 100, 500)

    usol, success = lsoda(funcptr, w0, x, data=p)
    
    return usol[-1][1]

然后我使用另一个文件一个接一个地解决系统

from numba_func import solve, funcptr

gs = []
for a, b, c in zip(as, bs, cs):
    gs = np.append(gs, solve(a, b, c, funcptr))

我收到以下警告:

NumbaWarning: Cannot cache compiled function "solve" as it uses dynamic globals (such as ctypes pointers and large global arrays)

我想这个想法是正确传递变量 funcptr 以便 numba 很高兴但到目前为止我失败了......

标签: numba

解决方案


推荐阅读