python-2.7 - 我可以向 numba.njit 提供一些参数类型,但让它推断其余的吗?
问题描述
我正在使用 numba.njit;它可以轻松推断类型,H1()
但不能推断H2()
我提供函数F
作为参数之一的位置。
有没有办法可以告诉我numba.njit
的类型F
,但让它推断剩余的类型,所以我不必提供整体签名?
当我运行以下代码时...
import numpy as np
import numba
@numba.njit
def F1(s):
return 1/s
@numba.njit
def H1(s, p):
return F1(s)/(F1(s)+p['tau'])
@numba.njit
def H2(s, p, F):
return F(s)/(F(s)+p['tau'])
def prepare_params(x=None):
try:
shape = x.shape
except LookupError:
shape = ()
f64 = np.dtype(np.float64)
p = np.zeros(shape=shape, dtype=[('tau',f64),
('something',f64)])
p['tau'] = 0.001
p['something'] = 1
return p
s = np.logspace(0,2,5)*1j
p = prepare_params(s)
print "H1=", H1(s,p)
print "H2=", H2(s,p,F1)
我明白了:
H1= [ 0.99999900-0.001j 0.99999000-0.00316225j 0.99990001-0.009999j
0.99900100-0.03159119j 0.99009901-0.0990099j ]
H2=
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
<ipython-input-18-2a1359160fe9> in <module>()
30 p = prepare_params(s)
31 print "H1=", H1(s,p)
---> 32 print "H2=", H2(s,p,F1)
c:\app\python\anaconda\2\lib\site-packages\numba\dispatcher.pyc in _compile_for_args(self, *args, **kws)
328 for i, err in failed_args))
329 e.patch_message(msg)
--> 330 raise e
331
332 def inspect_llvm(self, signature=None):
TypingError: Caused By:
Traceback (most recent call last):
File "c:\app\python\anaconda\2\lib\site-packages\numba\compiler.py", line 238, in run
stage()
File "c:\app\python\anaconda\2\lib\site-packages\numba\compiler.py", line 452, in stage_nopython_frontend
self.locals)
File "c:\app\python\anaconda\2\lib\site-packages\numba\compiler.py", line 865, in type_inference_stage
infer.propagate()
File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 844, in propagate
raise errors[0]
TypingError: Internal error at <numba.typeinfer.ArgConstraint object at 0x00000000089DD320>:
--%<-----------------------------------------------------------------
Traceback (most recent call last):
File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 137, in propagate
constraint(typeinfer)
File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 190, in __call__
typeinfer.add_type(self.dst, ty, loc=self.loc)
File "c:\app\python\anaconda\2\lib\contextlib.py", line 35, in __exit__
self.gen.throw(type, value, traceback)
File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 265, in new_error_context
six.reraise(type(newerr), newerr, sys.exc_info()[2])
File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 259, in new_error_context
yield
File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 189, in __call__
assert ty.is_precise()
InternalError:
[1] During: typing of argument at <ipython-input-18-2a1359160fe9> (15)
--%<-----------------------------------------------------------------
File "<ipython-input-18-2a1359160fe9>", line 15
Failed at nopython (nopython frontend)
Internal error at <numba.typeinfer.ArgConstraint object at 0x00000000089DD320>:
--%<-----------------------------------------------------------------
Traceback (most recent call last):
File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 137, in propagate
constraint(typeinfer)
File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 190, in __call__
typeinfer.add_type(self.dst, ty, loc=self.loc)
File "c:\app\python\anaconda\2\lib\contextlib.py", line 35, in __exit__
self.gen.throw(type, value, traceback)
File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 265, in new_error_context
six.reraise(type(newerr), newerr, sys.exc_info()[2])
File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 259, in new_error_context
yield
File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 189, in __call__
assert ty.is_precise()
InternalError:
[1] During: typing of argument at <ipython-input-18-2a1359160fe9> (15)
--%<-----------------------------------------------------------------
File "<ipython-input-18-2a1359160fe9>", line 15
This error may have been caused by the following argument(s):
- argument 2: cannot determine Numba type of <class 'numba.targets.registry.CPUDispatcher'>
解决方案
没关系,当我写这个问题时,我使用的是 Numba 0.35;我已经升级到 0.46 并且工作正常:
H1= [ 0.99999900-0.001j 0.99999000-0.00316225j 0.99990001-0.009999j
0.99900100-0.03159119j 0.99009901-0.0990099j ]
H2= [ 0.99999900-0.001j 0.99999000-0.00316225j 0.99990001-0.009999j
0.99900100-0.03159119j 0.99009901-0.0990099j ]
我可以将函数作为参数传递给 jited 函数吗?
从 Numba 0.39 开始,只要函数参数也经过 JIT 编译,您就可以:
推荐阅读
- pandas - pandas:如何使用 predict_proba 获取每个客户的概率
- oracle - 什么是预言机事件记录?
- laravel - 在 Lumen 包开发中未触发包级 Composer
- java - 在 Java 编译时生成枚举
- excel - 工作簿是 VBA Excel 中的类还是集合?
- json - 如何查看 JSON 文件的内容?数据类型错误
- php - 根据多个答案显示表单字段
- java - 如何在android中将圆圈作为标记添加到mapbox?
- javascript - 在 UI5 中编辑/更新绑定数据时未发生数据绑定
- python - 如何构建 python 项目以将其部署为公共网站?