首页 > 解决方案 > 使用 numpy 提高 FOR 循环性能

问题描述

以下代码片段需要 0.7 秒。我想提高速度。基本上,代码将 b 中具有与 a 相同索引的所有值相加,并将它们存储在同一位置,在 a 的索引中找到,但在不同的数组中。所以基本上数组 a 保存的值范围为 0-255,代表临时数组的索引。

a = np.random.randint(256, size=(40000,2))
b= np.arange(1280000).reshape(40000, 32)
temp = np.zeros((1,32,256,256)) 

for indx, pnt in enumerate(a):
    temp[0,:,pnt[0],pnt[1]] += b[indx,:]

谢谢。

标签: numpyloopsfor-loop

解决方案


这个答案可能有点偏左,但我认为不值得向后弯腰尝试矢量化。

numba改为在这个问题上抛出一些;

from numba import jit

@jit
def compute_stuff(a, b):
    temp = np.zeros((1,32,256,256)) 
    for indx, pnt in enumerate(a):
        temp[0,:,pnt[0],pnt[1]] += b[indx,:]
    return temp

而且它至少要快一点。

似乎您也知道a必须小于 256,因此您可以通过指定数组的数据类型来节省内存并可能获得一些性能;

a = np.random.randint(256, size=(40000,2), dtype=np.uint8)

推荐阅读