python-3.x - 可选择使用 jit 将参数传递给另一个函数
问题描述
我正在尝试 jit 编译 python 函数,并使用可选参数来更改另一个函数调用的参数。
我认为 jit 可能出错的地方是可选参数的默认值为 None,并且 jit 不知道如何处理它,或者至少不知道当它更改为 numpy 数组时如何处理它。请参阅下面的粗略概述:
@jit(nopython=True)
def foo(otherFunc,arg1, optionalArg=None):
if optionalArg is not None:
out=otherFunc(arg1,optionalArg)
else:
out=otherFunc(arg1)
return out
其中 optionalArg 是 None 或 numpy 数组
一种解决方案是将其转换为如下所示的三个函数,但这感觉有点笨拙,我不喜欢它,特别是因为速度对于这项任务非常重要。
def foo(otherFunc,arg1,optionalArg=None):
if optionalArg is not None:
out=func1(otherFunc,arg1,optionalArg)
else:
out=func2(otherFunc,arg1)
return out
@jit(nopython=True)
def func1(otherFunc,arg1,optionalArg):
out=otherFunc(arg1,optionalArg)
return out
@jit(nopython=True)
def func2(otherFunc,arg1):
out=otherFunc(arg1)
return out
请注意,除了调用 otherFunc 之外,还发生了其他事情,这使得使用 jit 值得,但我几乎可以肯定这不是问题所在,因为这在没有 optionalArg 部分之前可以工作,所以我决定不包括它。
对于那些好奇它的 runge-kutta 4 阶实现的人,它带有可选的额外参数以传递给微分方程。如果你想看到整个事情只是问。
回溯相当长,但这里有一些:
inte.rk4(de2,y0,0.001,200,vals=np.ones(4))
Traceback (most recent call last):
File "<ipython-input-38-478197aa6a1a>", line 1, in <module>
inte.rk4(de2,y0,0.001,200,vals=np.ones(4))
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E168C358>:
This continues...
inte.rk4 是 foo 的等价物,de2 是 otherFunc,y0、0.001 和 200 只是值,我在上面的问题描述中换成了 arg1,而 vals 是 optionalArg。
当我尝试在省略 vals 参数的情况下运行它时,也会发生类似的事情:
ysExp=inte.rk4(deExp,y0,0.001,200)
Traceback (most recent call last):
File "<ipython-input-39-7dde4bcbdc2f>", line 1, in <module>
ysExp=inte.rk4(deExp,y0,0.001,200)
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E048EA90>:
This continues...
解决方案
如果您在此处查看文档,您可以optional
在 Numba 中显式指定类型参数。例如(这是文档中的相同示例):
>>> @jit((optional(intp),))
... def f(x):
... return x is not None
...
>>> f(0)
True
>>> f(None)
False
此外,根据有关此 Github 问题的对话,您可以使用以下解决方法来实现可选关键字。我已经修改了 github 问题中提供的解决方案中的代码以适合您的示例:
from numba import jitclass, int32, njit
from collections import OrderedDict
import numpy as np
np_arr = np.asarray([1,2])
spec = OrderedDict()
spec['x'] = int32
@jitclass(spec)
class Foo(object):
def __init__(self, x):
self.x = x
def otherFunc(self, optionalArg):
if optionalArg is None:
return self.x + 10
else:
return len(optionalArg)
@njit
def useOtherFunc(arg1, optArg):
foo = Foo(arg1)
print(foo.otherFunc(optArg))
arg1 = 5
useOtherFunc(arg1, np_arr) # Output: 2
useOtherFunc(arg1, None) # Output : 15
有关上面显示的示例,请参阅此 colab 笔记本。
推荐阅读
- algorithm - 匈牙利算法用户的最佳策略是什么?
- julia - 创建要循环的整数索引的最佳/最有效方法是什么?
- android - Android“绕过用户批准”第二次以编程方式连接到特定 Wi-Fi 网络不起作用
- gitlab - 根据条件在 GitLab 中跳过 YAML 文件中的块
- mongodb - 如何使用 mongodb 中两个集合的聚合计算利润?
- python - 如何在使用 R markdown 时将 Python 代码的输出拟合到 pdf 中?
- stenciljs - StencilJS:可以在@State 上使用@Watch
- php - 为什么我的数据没有提交到数据库中
- mysql - SQL 错误 (1093) 您不能在 FROM 子句中指定要更新的目标表和子查询
- java - 连接Android Auto后RecyclerView ViewHolders消失