首页 > 解决方案 > 可选择使用 jit 将参数传递给另一个函数

问题描述

我正在尝试 jit 编译 python 函数,并使用可选参数来更改另一个函数调用的参数。

我认为 jit 可能出错的地方是可选参数的默认值为 None,并且 jit 不知道如何处理它,或者至少不知道当它更改为 numpy 数组时如何处理它。请参阅下面的粗略概述:

@jit(nopython=True)
def foo(otherFunc,arg1, optionalArg=None):

    if optionalArg is not None:
        out=otherFunc(arg1,optionalArg)

    else:
        out=otherFunc(arg1)
    return out

其中 optionalArg 是 None 或 numpy 数组

一种解决方案是将其转换为如下所示的三个函数,但这感觉有点笨拙,我不喜欢它,特别是因为速度对于这项任务非常重要。

def foo(otherFunc,arg1,optionalArg=None):

    if optionalArg is not None:
        out=func1(otherFunc,arg1,optionalArg)
    else:
        out=func2(otherFunc,arg1)
    return out

@jit(nopython=True)
def func1(otherFunc,arg1,optionalArg):
    out=otherFunc(arg1,optionalArg)
    return out

@jit(nopython=True)
def func2(otherFunc,arg1):
    out=otherFunc(arg1)
    return out

请注意,除了调用 otherFunc 之外,还发生了其他事情,这使得使用 jit 值得,但我几乎可以肯定这不是问题所在,因为这在没有 optionalArg 部分之前可以工作,所以我决定不包括它。

对于那些好奇它的 runge-kutta 4 阶实现的人,它带有可选的额外参数以传递给微分方程。如果你想看到整个事情只是问。

回溯相当长,但这里有一些:

inte.rk4(de2,y0,0.001,200,vals=np.ones(4))
Traceback (most recent call last):

  File "<ipython-input-38-478197aa6a1a>", line 1, in <module>
    inte.rk4(de2,y0,0.001,200,vals=np.ones(4))

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
    error_rewrite(e, 'typing')

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
    reraise(type(e), e, None)

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
    raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E168C358>:

This continues...

inte.rk4 是 foo 的等价物,de2 是 otherFunc,y0、0.001 和 200 只是值,我在上面的问题描述中换成了 arg1,而 vals 是 optionalArg。

当我尝试在省略 vals 参数的情况下运行它时,也会发生类似的事情:

ysExp=inte.rk4(deExp,y0,0.001,200)
Traceback (most recent call last):

  File "<ipython-input-39-7dde4bcbdc2f>", line 1, in <module>
    ysExp=inte.rk4(deExp,y0,0.001,200)

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
    error_rewrite(e, 'typing')

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
    reraise(type(e), e, None)

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
    raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E048EA90>:

This continues...

标签: python-3.xjitnumba

解决方案


如果您在此处查看文档,您可以optional在 Numba 中显式指定类型参数。例如(这是文档中的相同示例):

>>> @jit((optional(intp),))
... def f(x):
...     return x is not None
...
>>> f(0)
True
>>> f(None)
False

此外,根据有关此 Github 问题的对话,您可以使用以下解决方法来实现可选关键字。我已经修改了 github 问题中提供的解决方案中的代码以适合您的示例:

from numba import jitclass, int32, njit
from collections import OrderedDict
import numpy as np

np_arr = np.asarray([1,2])

spec = OrderedDict()
spec['x'] = int32

@jitclass(spec)
class Foo(object):
    def __init__(self, x):
        self.x = x

    def otherFunc(self, optionalArg):
        if optionalArg is None:
            return self.x + 10
        else:
            return len(optionalArg)
@njit
def useOtherFunc(arg1, optArg):
    foo = Foo(arg1)

    print(foo.otherFunc(optArg))

arg1 = 5

useOtherFunc(arg1, np_arr)   # Output: 2
useOtherFunc(arg1, None)     # Output : 15

有关上面显示的示例,请参阅此 colab 笔记本。


推荐阅读