首页 > 解决方案 > Numba jited len() 比纯 Python len() 慢

问题描述

我正在学习 numba 并遇到了我不理解的这种“奇怪”行为。我尝试使用以下代码(在 iPython 中,用于计时):

import numpy as np
import numba as nb

@nb.njit
def nb_len(seq):
    return len(seq)

def py_len(seq):
    return len(seq)

##
t = np.random.rand(1000)

%timeit nb_len(t)
%timeit py_len(t)

结果如下(实际上是第二次运行由于numba的编译):

258 ns ± 1.37 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
137 ns ± 0.964 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

纯 python 版本的速度是 numba 版本的两倍。我也尝试过签名@nb.njit( nb.int32(nb.float64[:]) ),但结果仍然相同。

我在某个地方犯错了吗?

谢谢你。

标签: pythonnumpynumba

解决方案


增加时间的不是 len() 部分。使用输入参数调用 jit 函数会增加开销,这就是您看到的时间差。

import numba as nb

def py_pass(i):
    return i

@nb.njit()
def nb_pass(i):
    return i

%timeit py_pass(1)
%timeit nb_pass(1)

输入参数的结果

102 ns ± 0.371 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
165 ns ± 0.783 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

有趣的是,如果你不需要向 jit 函数传递任何东西,它会更快:

def py_pass():
    return 1

@nb.njit()
def nb_pass():
    return 1

%timeit py_pass()
%timeit nb_pass()

没有输入参数的结果

96.6 ns ± 0.278 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
75.8 ns ± 0.221 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

推荐阅读