首页 > 解决方案 > 在此函数上使用 numba 会引发此错误。可能是什么问题?

问题描述

这是我试图加速的功能。我正在使用使用 python 2.7 的 spyder 最新版本,Numba 版本是 0.38.0 -

@nb.njit(fastmath = True, parallel = True, error_model = "numpy", nogil = True)
def fun(a, b, c, d, ef):
    # start = time.time()
    m_d = np.array([-40, -40, -40])
    f = np.zeros((128, 128, 128), np.complex64)


    for i in range(0, len(d)):
        x = nb.int64(math.floor((ef[i][0] - m_d[0]) / 1.2))
        y = nb.int64(math.floor((ef[i][1] - m_d[1]) / 1.2))
        z = nb.int64(math.floor((ef[i][2] - m_d[2]) / 1.2))
        f[x][y][z] = complex(d[i])


    e  = 0
    g = np.zeros((128, 128, 128), np.complex64)
    X = Y = Z = 128

    for i in range(len(a)):
        x = a[i]
        y = b[i]
        z = c[i]
        for x2 in range(x - 1, x + 5):
            for y2 in range(y - 1, y + 5):
                for z2 in range(z - 1, z + 5):
                    if (-1 < x < X and
                        -1 < y < Y and
                        -1 < z < Z and
                        (x != x2 or y != y2 or z != z2) and
                        (0 <= x2 < X) and
                        (0 <= y2 < Y)and
                        (0 <= z2 < Z)):
                            q = f[x2][y2][z2]
                            di = np.sqrt((x - x2) ** 2 + (y - y2) ** 2 + (z - z2) ** 2) * 1.2
                            if di <= 6 and di >= 2:
                                e = 4
                            elif di > 6 and di < 8:
                                e = 38 * di - 224
                            elif di >= 8:
                                e = 80
                            else:
                                continue
                            value = q / (e * di)
                            g[x][y][z] = g[x][y][z] + value


    # print "fun : ", time.time() - start
    return g

错误是 -

task = get()
TypeError: ('__init__() takes exactly 3 arguments (2 given)', <class 'numba.errors.LoweringError'>, ('Failed at nopython (nopython mode backend)\nreflected list(array(float32, 1d, C)): unsupported nested memory-managed object\n\nFile "test_numba_errorful.py", line 702:\ndef fun(a, b, c, d, ef):\n    <source elided>\n    # m_d = np.array([-40, -40, -40])\n    f = np.zeros((128, 128, 128), np.complex64)\n    ^\n[1] During: lowering "ef = arg(4, name=)"

在修复了一些多处理开销之后,我现在收到了这个错误 -

File "/usr/local/lib/python2.7/dist-packages/numba/dispatcher.py", line 360, in _compile_for_args
    raise e

LoweringError: reflected list(array(float32, 1d, C)): unsupported nested memory-managed object

File "test_numba_errorful.py", line 702:
def fun(a, b, c, d, ef):
    <source elided>
    # m_d = np.array([-40, -40, -40])
    f = np.zeros((128, 128, 128), np.complex64)
    ^

是什么导致了错误?我该如何纠正它?

标签: pythoninitializationnested-loopsjitnumba

解决方案


一些东西:

  1. 模式抖动函数time中不支持该函数。nopython在此处查看支持的 python 功能列表:

http://numba.pydata.org/numba-doc/latest/reference/pysupported.html

  1. 您只能使用print作为功能进行打印。from __future__ import print_function使用 python 2 时你会想要。

更改上述两项允许代码使用 Numba 0.39 对输入进行猜测(我只尝试了标准的 numpy 数组)。但是,对于您正在使用的版本,从错误看来,您可能正在使用类似列表列表或 numpy 数组列表之类的东西,早期版本不支持这些东西。

另一个普遍的建议是,在处理多维数组时,访问总是x[i,j]x[i][j]性能更好。


推荐阅读