首页 > 解决方案 > 如何使用 Numba 正确加速?

问题描述

我目前正在加快我的 python 函数的速度。

def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)


def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

def distance(u,v,lon1,lat1):
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    return dlon, dlat

如您所见,这是基于 numpy.xml 的简单代码。我看了网上的大部分文章,他们说的只是把@numba.jit作为装饰器放在函数前面,然后我可以使用Numba来加速我的代码。

这是我所做的测试。

u = np.random.randn(10000)
v = np.random.randn(10000)
lon1 = np.random.uniform(-99,-96,10000)
lat1 = np.random.uniform( 23, 25,10000)
print(u)
%%timeit
for i in range(10000):
    distance(u,v,lon1,lat1)

每个循环 5.61 秒 ± 58.7 毫秒(平均值 ± 标准偏差。7 次运行,每个循环 1 个)

添加 Numba 装饰器

@numba.njit()
def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)

@numba.njit()
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

@numba.njit()
def distance(u, v, lon1, lat1, R=6.371*1e6):
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    return d_lon(lat1,lat2,dlon), d_lat(dlat)
%%timeit
for i in range(10000):
    a,b = distance(u,v,lon1,lat1)

每个循环 7.76 秒 ± 64.9 毫秒(平均值 ± 标准偏差。7 次运行,每个循环 1 个)

正如您在上面看到的,我的 Numba 案例的计算速度比我的纯 python 案例要慢。谁能帮我解决这个问题?

ps:numba
llvmlite 0.32.0rc1
numba 0.49.0rc2 的版本

------ 计算测试关于宏观经济学家的回答。------

根据他的回答,即使 Numba 现在已经足够聪明了,如果我们希望代码是 Numba 装饰的,最好使用普通的“Fortran”/“C”类型的样式。下面展示了我正在考虑的不同方法之间的计算时间比较。

def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)

def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

def distance(u,v,lon1,lat1):
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    return dlon, dlat
%%timeit
for i in range(10000):
    distance(u,v,lon1,lat1)

每个循环 54 秒 ± 485 毫秒(平均值 ± 标准偏差。7 次运行,每个循环 1 个)

@numba.jit(nogil=True)
def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)

@numba.jit(nogil=True)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

def distance(u, v, lon1, lat1, R=6.371*1e6):
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    return d_lon(lat1,lat2,dlon), d_lat(dlat)
%%timeit
for i in range(10000):
    a,b = distance(u,v,lon1,lat1)

每个循环 1 分钟 21 秒 ± 815 毫秒(平均值 ± 标准偏差。7 次运行,每个循环 1 个)

def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)

def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

@numba.njit(nogil=True)
def distance(u, v, lon1, lat1, R=6.371*1e6):
    def d_lat(dlat,R=6.371*1e6):
        return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)             
    def d_lon(lat1,lat2,dlon,R=6.371*1e6):
        return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                               np.cos(np.deg2rad(lat2)) *
                               np.sin(np.deg2rad(dlon)/2)**2)
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlat = d_lat(lat2 - lat1)
    dlon = d_lon(lat1,lat2,lon2 - lon1)
    return dlon, dlat
%%timeit
for i in range(10000):
    a,b = distance(u,v,lon1,lat1)

每个循环 1 分钟 2 秒 ± 239 毫秒(平均值 ± 标准偏差。7 次运行,每个循环 1 个)

@numba.njit() 
def d_lat(dlat,R=6.371*1e6): 
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2) 

@numba.njit() 
def d_lon(lat1,lat2,dlon,R=6.371*1e6): 
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *  
                           np.cos(np.deg2rad(lat2)) * 
                           np.sin(np.deg2rad(dlon)/2)**2) 

@numba.njit() 
def distance(u, v, lon1, lat1): 
    lon2 = np.empty_like(lon1) 
    lat2 = np.empty_like(lat1) 
    dlon = np.empty_like(lon1) 
    dlat = np.empty_like(lat1) 

    for i in range(len(v)): 
        vi = v[i] 
        if vi > 0: 
            lat2[i] = lat1[i]+1 
            dlat[i] = 1 
        elif vi < 0: 
            lat2[i] = lat1[i]-1 
            dlat[i] = -1 
        else: 
            lat2[i] = lat1[i] 
            dlat[i] = 0 

    for i in range(len(u)): 
        ui = u[i] 
        if ui > 0:  
            lon2[i] = lon1[i]+1 
            dlon[i] = 1 
        elif ui < 0: 
            lon2[i] = lon1[i]-1 
            dlon[i] = -1 
        else: 
            lon2[i] = lon1[i] 
            dlon[i] = 0 

    return d_lon(lat1,lat2,dlon), d_lat(dlat) 
%%timeit
for i in range(10000):
    distance(u,v,lon1,lat1)

每个循环 35.9 秒 ± 537 毫秒(平均值 ± 标准偏差。7 次运行,每个循环 1 个)

标签: python-3.xnumba

解决方案


有几个问题跳出来了。

首先,您在函数中的计算过于复杂,并且以一种可能不适合 Numba 编译器distance的风格(有很多花哨的索引,例如)编写。lat2[v>0]尽管 Numba 变得越来越聪明,但我发现以简单、面向循环的方式编写代码仍然有很高的回报。

其次,Numba 可以通过可选参数减慢一点。我发现这主要适用R于您的distance函数中的可选项。

解决这两个问题 - 特别是用更简单的循环替换你的矢量化代码以最小化操作 - 我们得到形式的代码

@numba.njit() 
def d_lat(dlat,R=6.371*1e6): 
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2) 

@numba.njit() 
def d_lon(lat1,lat2,dlon,R=6.371*1e6): 
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *  
                           np.cos(np.deg2rad(lat2)) * 
                           np.sin(np.deg2rad(dlon)/2)**2) 

@numba.njit() 
def distance(u, v, lon1, lat1): 
    lon2 = np.empty_like(lon1) 
    lat2 = np.empty_like(lat1) 
    dlon = np.empty_like(lon1) 
    dlat = np.empty_like(lat1) 

    for i in range(len(v)): 
        vi = v[i] 
        if vi > 0: 
            lat2[i] = lat1[i]+1 
            dlat[i] = 1 
        elif vi < 0: 
            lat2[i] = lat1[i]-1 
            dlat[i] = -1 
        else: 
            lat2[i] = lat1[i] 
            dlat[i] = 0 

    for i in range(len(u)): 
        ui = u[i] 
        if ui > 0:  
            lon2[i] = lon1[i]+1 
            dlon[i] = 1 
        elif ui < 0: 
            lon2[i] = lon1[i]-1 
            dlon[i] = -1 
        else: 
            lon2[i] = lon1[i] 
            dlon[i] = 0 

    return d_lon(lat1,lat2,dlon), d_lat(dlat) 

在我的(较慢的)系统上,这将初始编译成本后的时间从大约 7 秒减少到大约 4 秒。在这一点上,我相信成本主要由所有功能的原始成本np.sin, np.cos,np.exp等决定。


推荐阅读