首页 > 解决方案 > 为什么在使用列表时用 numba 计算总和会更慢?

问题描述

这是我的代码:

@numba.jit( )
def dis4(x1,x2):
    s=0.0
    for i in range(len(x1)):
        s+=(x1[i]-x2[i])**2
    return math.sqrt(s)
x1=[random.random() for _ in range(m)]
x2=[random.random() for _ in range(m)]
%timeit dis4(x1,x2)

每个循环 3.32 毫秒 ± 37.8 微秒(平均值 ± 标准偏差。7 次运行,每次 100 次循环)

相比之下,没有jit.

每个循环 137 µs ± 1.62 µs(7 次运行的平均值 ± 标准偏差,每次 10000 个循环)

标签: pythonlistperformancejitnumba

解决方案


它更慢,因为 numba (默默地)复制了列表。

要了解为什么会发生这种情况,您需要知道 numba 具有 object-mode 和 nopython-mode。在对象模式下,它可以对 Python 数据结构进行操作,但是它不会比普通的 Python 函数快多少,甚至会更慢(至少在一般情况下,有非常罕见的例外)。在 nopython 模式下 numba 不能对 Python 数据结构(如 )进行操作list,因此为了使lists 工作,它必须使用非 Python 列表。要从 Python 列表创建这样的非 Python 列表(称为反射列表),它必须复制和转换列表内容。

在您的情况下,这种复制和转换会使其速度变慢。

这也是为什么人们通常应该避免使用非数组参数或使用 numba 函数返回的原因。数组的内容不需要转换,至少如果数组的 dtype 受 numba 支持,那么这些是“安全的”。

如果这些数据结构(列表、元组、集合)被限制在 numba 中,它们很好 - 但是当它们跨越 numba⭤ Python 边界时,它们必须被复制,这(几乎)总是会抵消所有的性能提升。


只是为了展示函数如何使用数组执行:

import math
import random
import numba as nb
import numpy as np

def dis4_plain(x1,x2):
    s=0.0
    for i in range(len(x1)):
        s+=(x1[i]-x2[i])**2
    return math.sqrt(s)

@nb.jit
def dis4(x1,x2):
    s=0.0
    for i in range(len(x1)):
        s+=(x1[i]-x2[i])**2
    return math.sqrt(s)

m = 10_000
x1 = [random.random() for _ in range(m)]
x2 = [random.random() for _ in range(m)]
a1 = np.array(x1)
a2 = np.array(x2)

定时:

dis4(x1, x2)
dis4(a1, a2)

%timeit dis4_plain(x1, x2)
# 2.71 ms ± 178 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit dis4(x1, x2)
# 24.1 ms ± 279 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit dis4(a1, a2)
# 14 µs ± 608 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

因此,虽然使用 list 和 the 慢 10 倍numba.jit,但使用数组的 jitted 函数几乎比使用列表的 Python 函数快 200 倍。


推荐阅读