python - Numba jited len() 比纯 Python len() 慢
问题描述
我正在学习 numba 并遇到了我不理解的这种“奇怪”行为。我尝试使用以下代码(在 iPython 中,用于计时):
import numpy as np
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
def py_len(seq):
return len(seq)
##
t = np.random.rand(1000)
%timeit nb_len(t)
%timeit py_len(t)
结果如下(实际上是第二次运行由于numba的编译):
258 ns ± 1.37 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
137 ns ± 0.964 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
纯 python 版本的速度是 numba 版本的两倍。我也尝试过签名@nb.njit( nb.int32(nb.float64[:]) )
,但结果仍然相同。
我在某个地方犯错了吗?
谢谢你。
解决方案
增加时间的不是 len() 部分。使用输入参数调用 jit 函数会增加开销,这就是您看到的时间差。
import numba as nb
def py_pass(i):
return i
@nb.njit()
def nb_pass(i):
return i
%timeit py_pass(1)
%timeit nb_pass(1)
输入参数的结果
102 ns ± 0.371 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
165 ns ± 0.783 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
有趣的是,如果你不需要向 jit 函数传递任何东西,它会更快:
def py_pass():
return 1
@nb.njit()
def nb_pass():
return 1
%timeit py_pass()
%timeit nb_pass()
没有输入参数的结果
96.6 ns ± 0.278 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
75.8 ns ± 0.221 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
推荐阅读
- c++ - 我们可以从 TXT 文件中读取数据并使用 C 和 C++ 将其保存到 SQL 数据库吗?
- c# - 为什么迭代器在异常上的行为与 LINQ 可枚举不同?
- python - 使用gitlab cicd自动合并分支
- c++ - 添加可变参数函数的错误结果
- android - 如何在活动中显示相机预览?
- twilio - 使用时
TwiML 调用总是失败,调试器中没有错误并且客户端处于活动状态 - java - 如何在 SpringBoot & JavaFx 应用程序的主类中注入 bean
- python - 以列名作为 x 轴绘制直方图
- angular - mat-error 在我创建的组件上不可见
- python - TensorFlow 的可微分汉明损失