Nested classes in Numba

Problem description

I am trying to write nested classes in Numba for computing cubic splines. The first (1D) case looks like this:

import numba as nb
from numba import float64, int32
from numba.experimental import jitclass

spec = [
    ('x', float64[:]),               # 1D float64 array
    ('y', float64[:]),
    ('b', nb.types.List(nb.float64)),
    ('c', float64[:]),
    ('d', nb.types.List(nb.float64)),
    ('w', nb.types.List(nb.float64)),
    ('nx', int32),
    ('a', nb.types.List(nb.float64)),
]

@jitclass(spec)
class Spline:
    """
    Cubic Spline class
    """

This works fine. I then wrote the 2D case as:

from numba import deferred_type

spline_1d_type = deferred_type()
spline_1d_type.define(Spline.class_type.instance_type)

spec_2d = [
    ('s0', float64[:]),
    ('sy', spline_1d_type),
    ('sx', spline_1d_type),
    ('ds', nb.types.List(nb.float64)),
    ('x', float64[:]),               # 1D float64 array
    ('y', float64[:]),
]


@jitclass(spec_2d)
class Spline2D:
    """
    2D Cubic Spline class
    """

But this gives me the following error:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
Failed in nopython mode pipeline (step: nopython mode backend)
LLVM IR parsing error
<string>:498:278: error: base element of getelementptr must be sized
  %".1077" = getelementptr inbounds {{i8*, i8*, i64, i64, double*, [1 x i64], [1 x i64]}, %"deferred.139921649954192.data", %"deferred.139921649954192.data", {i8*, i8*}, {i8*, i8*, i64, i64, double*, [1 x i64], [1 x i64]}, {i8*, i8*, i64, i64, double*, [1 x i64], [1 x i64]}}, {{i8*, i8*, i64, i64, double*, [1 x i64], [1 x i64]}, %"deferred.139921649954192.data", %"deferred.139921649954192.data", {i8*, i8*}, {i8*, i8*, i64, i64, double*, [1 x i64], [1 x i64]}, {i8*, i8*, i64, i64, double*, [1 x i64], [1 x i64]}}* %".1076", i32 0, i32 3
                                                                                                                                                                                                                                                                                     ^

- Resolution failure for non-literal arguments:
None

During: resolving callee type: BoundFunction((<class 'numba.core.types.misc.ClassInstanceType'>, '_Spline2D__calc_s') for instance.jitclass.Spline2D#7f420c3defa0<s0:array(float64, 1d, A),sy:DeferredType#139921649954192,sx:DeferredType#139921649954192,ds:list(float64)<iv=None>,x:array(float64, 1d, A),y:array(float64, 1d, A)>)
During: typing of call at <ipython-input-7-8958d720d848> (161)


File "<ipython-input-7-8958d720d848>", line 161:
    def __init__(self, x, y):
        self.s0 = self.__calc_s(x, y)
        ^

During: resolving callee type: jitclass.Spline2D#7f420c3defa0<s0:array(float64, 1d, A),sy:DeferredType#139921649954192,sx:DeferredType#139921649954192,ds:list(float64)<iv=None>,x:array(float64, 1d, A),y:array(float64, 1d, A)>
During: typing of call at <string> (3)

During: resolving callee type: jitclass.Spline2D#7f420c3defa0<s0:array(float64, 1d, A),sy:DeferredType#139921649954192,sx:DeferredType#139921649954192,ds:list(float64)<iv=None>,x:array(float64, 1d, A),y:array(float64, 1d, A)>
During: typing of call at <string> (3)


File "<string>", line 3:
<source missing, REPL/exec in use?>

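I think this happens because the size of the elements in the 1D Spline class may be undefined, but I'm not sure. It would be great if someone could point me in the right direction.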

Tags: python, python-3.x, jit, numba
