首页 > 解决方案 > 为什么 numba.jit 和 python 会产生不同的结果?

问题描述

在下面的代码中,numba.jit 产生 7732.96... 而 python 产生 -6351.97...(为简洁起见,省略了数字)。我能做些什么来解决这个问题?这是 numba 的错误还是我的编码错误?我在 Spyder 3 上使用了 Python 3.7 (anaconda)。

from numba import jit
import numpy as np

@jit(nopython=True)
def test(n):
     sum = 0.0
     arr = np.arange(2, n)
     for x in np.sin(np.cos(arr ** 2)):
           sum += x
     return sum


a = test(100000000)
print(a)

标签: pythonnumba

解决方案


我无法重现您的错误:

from numba import jit
import numpy as np

def test(n):
    sum = 0.0
    arr = np.arange(2, n)
    for x in np.sin(np.cos(arr ** 2)):
        sum += x
    return sum

testnb = jit(nopython=True)(test)

N = 100000000
print(test(N))
print(testnb(N))
# 7732.969676855288
# 7732.969676855337

我正在使用 numba 0.45.1、python 3.7.3 和 numpy 1.16.4。我最初的猜测是存在某种浮点问题,在非 jitted 形式中,sum是一个具有无限精度的 python 值,而在 jitted 代码中,sum根据您的系统键入特定的 float32 或 float64。但是对于您的特定系统,我不确定发生了什么。


推荐阅读