首页 > 解决方案 > 什么时候numba有效?

问题描述

我知道 numba 会产生一些开销,并且在某些情况下(非密集计算)它会比纯 python 慢。但我不知道在哪里画线。是否可以使用算法复杂度的顺序来确定在哪里?

例如,在此代码中添加两个比 5 短的数组 (~O(n)),纯 python 更快:

def sum_1(a,b):
    result = 0.0
    for i,j in zip(a,b):
            result += (i+j)
    return result

@numba.jit('float64[:](float64[:],float64[:])')
def sum_2(a,b):
    result = 0.0
    for i,j in zip(a,b):
            result += (i+j)
    return result

# try 100
a = np.linspace(1.0,2.0,5)
b = np.linspace(1.0,2.0,5)
print("pure python: ")
%timeit -o sum_1(a,b)
print("\n\n\n\npython + numba: ")
%timeit -o sum_2(a,b)

UPDADE:我正在寻找的是类似这里的指南:

“一般准则是为不同的数据大小和算法选择不同的目标。“cpu”目标适用于小数据大小(大约小于 1KB)和低计算强度算法。它的开销最少。“并行”目标适用于中等数据大小(大约小于 1MB)。线程增加了一个小延迟。“cuda”目标适用于大数据大小(大约大于 1MB)和高计算强度算法。将内存传输到并且来自 GPU 会增加大量开销。”

标签: pythonpython-3.xperformancenumba

解决方案


当 numba 生效时,很难划清界限。但是,有一些指标可能无效:

  • 如果你不能使用jitwith nopython=True- 每当你不能在 nopython 模式下编译它时,你要么尝试编译太多,要么它不会明显更快。

  • 如果您不使用数组 - 当您处理传递给 numba 函数的列表或其他类型(其他 numba 函数除外)时,numba 需要复制这些会产生大量开销。

  • 如果已经有一个 NumPy 或 SciPy 函数可以做到这一点 - 即使 numba 对于短数组可以明显更快,它几乎总是对于更长的数组同样快(你也可能很容易忽略这些可以处理的一些常见边缘情况)。

在 numba 比其他解决方案“有点”快的情况下,您可能不想使用它还有另一个原因:必须提前编译 Numba 函数,或者在第一次调用时编译,在某些情况下编译会比你的收获慢得多,即使你调用它数百次。编译时间也增加了:numba 导入速度很慢,编译 numba 函数也增加了一些开销。如果导入开销增加 1-10 秒,那么减少几毫秒是没有意义的。

numba 的安装也很复杂(至少没有 conda),所以如果你想共享你的代码,那么你有一个真正的“高度依赖”。


您的示例缺少与 NumPy 方法和高度优化的纯 Python 版本的比较。我添加了一些比较函数并做了一个基准测试(使用我的库simple_benchmark):

import numpy as np
import numba as nb
from itertools import chain

def python_loop(a,b):
    result = 0.0
    for i,j in zip(a,b):
        result += (i+j)
    return result

@nb.njit
def numba_loop(a,b):
    result = 0.0
    for i,j in zip(a,b):
            result += (i+j)
    return result

def numpy_methods(a, b):
    return a.sum() + b.sum()

def python_sum(a, b):
    return sum(chain(a.tolist(), b.tolist()))

from simple_benchmark import benchmark, MultiArgument

arguments = {
    2**i: MultiArgument([np.zeros(2**i), np.zeros(2**i)])
    for i in range(2, 17)
}
b = benchmark([python_loop, numba_loop, numpy_methods, python_sum], arguments, warmups=[numba_loop])

%matplotlib notebook
b.plot()

在此处输入图像描述

是的,numba 函数对于小数组来说是最快的,但是对于较长的数组,NumPy 解决方案会稍微快一些。Python 解决方案速度较慢,但​​“更快”的替代方案已经比您最初提出的解决方案快得多。

在这种情况下,我会简单地使用 NumPy 解决方案,因为它简短、易读且速度快,除非您处理大量短数组并多次调用该函数 - 那么 numba 解决方案会好得多。


推荐阅读