首页 > 解决方案 > Numba 降低了简单 for 循环函数的性能

问题描述

我正在处理相对较大的数据集(100,000 多个元素),需要为这些数据集创建一个邻接矩阵。

我编写了一个非常基本的 for 循环,可以为给定的连接节点(nx2)完成此操作

nodes = np.random.randint(20000, size=(20000, 2))

def adjMat(node_list):
    n = np.max(node_list)
    A = np.zeros((n, n))
    for tail, head in node_list:
        A[tail-1, head-1] = 1
    return A

这工作正常,并不像我想象的那么慢,但假设我可以通过使用 numba 来显着提高性能这个超级简单的功能。

所以我添加了两个 jitted 函数(一个使用并行)来查看性能差异。我也刚刚包含了networkx,看看它是否得到了很好的优化。

@njit()
def adjMat_numba(node_list):
    n = np.max(node_list)
    A = np.zeros((n, n))
    for tail, head in node_list:
        A[tail-1, head-1] = 1
    return A

@njit(parallel = True)
def adjMat_numba_para(node_list):
    n = np.max(node_list)
    A = np.zeros((n, n))
    for tail, head in node_list:
        A[tail-1, head-1] = 1
    return A

def getAdjacenyList(node_list):
    G = nx.Graph([e for e in node_list])
    A = nx.convert.to_dict_of_lists(G)

    return A

这是我在 20000 对连接节点上测试的输出:

%timeit a = adjMat(nodes)
112 ms ± 3.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit b = adjMat_numba(nodes)
1.34 s ± 41.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit c = adjMat_numba_para(nodes)
251 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit d = getAdjacenyList(nodes)
149 ms ± 3.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

令我惊讶的是,使用 numba 实际上会使函数慢几倍,而且即使在并行模式下,它仍然没有 for 循环那么快。Numba 似乎也比 for 循环使用更多的内存。另外,我很惊讶 networkx 比 for 循环慢 - 我本来希望一个唯一目的是处理这类问题的库会更快。

我的 numba 装饰器有问题吗?有没有更好的选择来快速有效地创建邻接矩阵?

我正在使用 pycharm 在 12 核 linux 桌面上运行这些测试。

标签: pythonnumpynumbaadjacency-matrix

解决方案


In [22]: node_list = np.array([[0,1],[1,4],[4,2],[2,0]])                                                                                                                                     

In [27]: adjMat(node_list+1)                                                                         
Out[27]: 
array([[0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0.]])

一种非迭代的 numpy 方法:

In [28]: res = np.zeros((5,5))                                                                       
In [29]: res[node_list[0],node_list[1]] = 1                                                          
In [30]: res                                                                                         
Out[30]: 
array([[0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]])

这真的需要numba吗?


推荐阅读