首页 > 解决方案 > 如何在简单的 python 函数中使用 numba njit?

问题描述

我创建了一个简单的函数,意味着取一个向量并沿几个矩阵将结果乘以,其中 k 只是 B 和 C 的长度。不断发生的是该函数给出了错误:

Invalid use of Function(<built-in function dot>) with argument(s) of type(s):
(reflected list(reflected list(float64)), array(float64, 2d, C))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.

...

def calc(A,k, B, C):
<source elided>
x2 = C[i]
c = np.dot(b,x1)
^

我确保将功能拆分为更易于遵循的方法:

@njit(fastmath = True)
def calc(A,k,B,C):
    b = A
    for i in range(k):
        x1 = B[i]
        x2 = C[i]
        c = np.dot(b,x1)
        D = c + x2
        c = np.tanh(D)
        b = c
    return b

有没有解决这个问题的好方法?如果是这样,正确的方法是什么?

标签: pythonnumba

解决方案


推荐阅读