Can I provide some argument types to numba.njit, but let it infer the rest?

Question

I'm using numba.njit; it can easily infer the types for H1(), but not for H2(), where I pass the function F as one of the arguments.

Is there a way to tell numba.njit the type of F but let it infer the remaining types, so that I don't have to provide the whole signature?

When I run the following code...

import numpy as np
import numba 


@numba.njit
def F1(s):
    return 1/s

@numba.njit
def H1(s, p):
    return F1(s)/(F1(s)+p['tau'])

@numba.njit
def H2(s, p, F):
    return F(s)/(F(s)+p['tau'])

def prepare_params(x=None):
    try:
        shape = x.shape
    except AttributeError:
        shape = ()
    f64 = np.dtype(np.float64)
    p = np.zeros(shape=shape, dtype=[('tau',f64),
                                     ('something',f64)])
    p['tau'] = 0.001
    p['something'] = 1
    return p

s = np.logspace(0,2,5)*1j
p = prepare_params(s)
print "H1=", H1(s,p)
print "H2=", H2(s,p,F1)

I get:

H1= [ 0.99999900-0.001j       0.99999000-0.00316225j  0.99990001-0.009999j
  0.99900100-0.03159119j  0.99009901-0.0990099j ]
H2=
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<ipython-input-18-2a1359160fe9> in <module>()
     30 p = prepare_params(s)
     31 print "H1=", H1(s,p)
---> 32 print "H2=", H2(s,p,F1)

c:\app\python\anaconda\2\lib\site-packages\numba\dispatcher.pyc in _compile_for_args(self, *args, **kws)
    328                                 for i, err in failed_args))
    329                 e.patch_message(msg)
--> 330             raise e
    331 
    332     def inspect_llvm(self, signature=None):

TypingError: Caused By:
Traceback (most recent call last):
  File "c:\app\python\anaconda\2\lib\site-packages\numba\compiler.py", line 238, in run
    stage()
  File "c:\app\python\anaconda\2\lib\site-packages\numba\compiler.py", line 452, in stage_nopython_frontend
    self.locals)
  File "c:\app\python\anaconda\2\lib\site-packages\numba\compiler.py", line 865, in type_inference_stage
    infer.propagate()
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 844, in propagate
    raise errors[0]
TypingError: Internal error at <numba.typeinfer.ArgConstraint object at 0x00000000089DD320>:
--%<-----------------------------------------------------------------
Traceback (most recent call last):
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 137, in propagate
    constraint(typeinfer)
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 190, in __call__
    typeinfer.add_type(self.dst, ty, loc=self.loc)
  File "c:\app\python\anaconda\2\lib\contextlib.py", line 35, in __exit__
    self.gen.throw(type, value, traceback)
  File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 265, in new_error_context
    six.reraise(type(newerr), newerr, sys.exc_info()[2])
  File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 259, in new_error_context
    yield
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 189, in __call__
    assert ty.is_precise()
InternalError: 
[1] During: typing of argument at <ipython-input-18-2a1359160fe9> (15)
--%<-----------------------------------------------------------------

File "<ipython-input-18-2a1359160fe9>", line 15

Failed at nopython (nopython frontend)
Internal error at <numba.typeinfer.ArgConstraint object at 0x00000000089DD320>:
--%<-----------------------------------------------------------------
Traceback (most recent call last):
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 137, in propagate
    constraint(typeinfer)
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 190, in __call__
    typeinfer.add_type(self.dst, ty, loc=self.loc)
  File "c:\app\python\anaconda\2\lib\contextlib.py", line 35, in __exit__
    self.gen.throw(type, value, traceback)
  File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 265, in new_error_context
    six.reraise(type(newerr), newerr, sys.exc_info()[2])
  File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 259, in new_error_context
    yield
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 189, in __call__
    assert ty.is_precise()
InternalError: 
[1] During: typing of argument at <ipython-input-18-2a1359160fe9> (15)
--%<-----------------------------------------------------------------

File "<ipython-input-18-2a1359160fe9>", line 15

This error may have been caused by the following argument(s):
- argument 2: cannot determine Numba type of <class 'numba.targets.registry.CPUDispatcher'>

Tags: python-2.7, numba

Answer


Never mind. When I wrote this question I was using Numba 0.35; I have since upgraded to 0.46 and it works fine:

H1= [ 0.99999900-0.001j       0.99999000-0.00316225j  0.99990001-0.009999j
  0.99900100-0.03159119j  0.99009901-0.0990099j ]
H2= [ 0.99999900-0.001j       0.99999000-0.00316225j  0.99990001-0.009999j
  0.99900100-0.03159119j  0.99009901-0.0990099j ]

I found this entry in the Numba FAQ:

Can I pass a function as an argument to a jitted function?

As of Numba 0.39, you can, so long as the function argument has also been JIT-compiled:
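The snippet from the FAQ did not survive in this copy, so here is a minimal sketch of the idea rather than the FAQ's own example. The names `inverse` and `ratio` and the scalar `tau` are hypothetical stand-ins for F1, H2, and p['tau'] from the question. The `ImportError` fallback is an assumption added only so the sketch still runs where Numba is not installed; in that case nothing is compiled.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba, fall back to plain Python so the sketch still runs.
    def njit(func):
        return func


@njit
def inverse(s):
    # Plays the role of F1 in the question.
    return 1.0 / s


@njit
def ratio(s, tau, F):
    # F is itself a jitted function passed as an ordinary argument.
    # No explicit signature is given; all argument types are inferred at call time.
    return F(s) / (F(s) + tau)


s = np.logspace(0, 2, 5) * 1j
result = ratio(s, 0.001, inverse)

# Reference value computed with plain NumPy for comparison.
expected = (1.0 / s) / (1.0 / s + 0.001)
print(np.allclose(result, expected))
```

On Numba 0.39 or later, no type annotation for F should be needed: the jitted function is typed like any other argument, which is why H2 in the question works after the upgrade.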

