How do I release the GIL with Numba and use multiple threads?

Problem description

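I want to build a dictionary whose keys are the unique values of one column of an array, with each key mapping to the rows that carry that value. The code below just shows what I am aiming for.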

import numpy as np

dic_value0 = {}
arr = np.zeros(1000000, dtype=np.dtype([('order', '<u4'), ('time', '<u8'), ('value0', '<f8'), ('value1', '<u2'), ('value2', '<i8')]))
arr['value0'] = np.random.randint(0, 500, size=1000000)

# Map each unique value0 to the rows that carry it.
for value0 in np.unique(arr['value0']):
    dic_value0[value0] = arr[arr['value0'] == value0]

My other goal is to do this quickly. To that end, I thought I could try Numba together with a thread pool.

import concurrent.futures
import time
import numba as nb
import numpy as np

dtype = np.dtype([('order', '<u4'), ('time', '<u8'), ('value0', '<f8'), ('value1', '<u2'), ('value2', '<i8')])

arr = np.zeros(1000000, dtype=dtype)
arr['order']  = np.arange(1000000)
arr['time']   = np.arange(1000000)
arr['value0'] = np.random.randint(0, 500, size=1000000)
arr['value1'] = np.random.randint(0, 500, size=1000000)
arr['value2'] = np.random.randint(0, 500, size=1000000)


def returnSelectedArr(arr_dset, value0):
    return arr_dset[ arr_dset['value0'] == value0 ] 
    
@nb.jit(nopython=True, nogil=True)
def returnSelectedArrNb(arr_dset, value0):
    return arr_dset[ arr_dset['value0'] == value0 ] 
    

arr_uniqueValue0 = np.unique( arr['value0'] )
dic_value0 = {}
dic_value0pool = {}
dic_value0pool4 = {}

#compilation
returnSelectedArrNb(arr, 0)


start = time.time()
for value0 in arr_uniqueValue0:
    dic_value0[value0] = returnSelectedArr(arr, value0)

end = time.time()
print("Elapsed (simple function) = %s" % (end - start))


executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
start = time.time()
for value0 in arr_uniqueValue0:
    dic_value0pool4[value0] = executor.submit(returnSelectedArrNb, arr, value0).result()    
end = time.time()
print("Elapsed (thread pool - 4 workers) = %s" % (end - start))


executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
start = time.time()
for value0 in arr_uniqueValue0:
    dic_value0pool[value0] = executor.submit(returnSelectedArrNb, arr, value0).result()
end = time.time()
print("Elapsed (thread pool - 1 worker) = %s" % (end - start))

But it still performs poorly. (Using 1 thread is even faster than using 4.)

Elapsed (simple function) = 3.479099988937378
Elapsed (thread pool - 4 workers) = 2.8787758350372314
Elapsed (thread pool - 1 worker) = 2.629206657409668

I would like to know what I am doing wrong and how to improve the performance.

Tags: python, threadpool, numba

Solution
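
One likely cause, offered as a sketch rather than a definitive answer: in the question's loops, `executor.submit(...).result()` blocks on each task before the next one is submitted, so only one worker is ever busy regardless of `max_workers`. In addition, the warm-up call passes the integer `0` while the loops pass `float64` values, so Numba probably compiles a second specialization inside the first timed loop. The version below submits every task first and collects the results afterwards. The smaller array size and the `ImportError` fallback to plain NumPy are assumptions added only so the example runs anywhere.

```python
import concurrent.futures

import numpy as np

try:
    import numba as nb
    # nogil=True lets the compiled function run without holding the GIL.
    jit = nb.jit(nopython=True, nogil=True)
except ImportError:
    # Assumption for this sketch only: fall back to plain NumPy if Numba is missing.
    def jit(func):
        return func

dtype = np.dtype([('order', '<u4'), ('time', '<u8'), ('value0', '<f8'),
                  ('value1', '<u2'), ('value2', '<i8')])

# Smaller than in the question, purely to keep the example quick.
arr = np.zeros(100_000, dtype=dtype)
arr['order'] = np.arange(arr.size)
arr['value0'] = np.random.randint(0, 50, size=arr.size)


@jit
def returnSelectedArrNb(arr_dset, value0):
    return arr_dset[arr_dset['value0'] == value0]


arr_uniqueValue0 = np.unique(arr['value0'])

# Warm up with a float64 argument, the same type the loop will pass.
returnSelectedArrNb(arr, arr_uniqueValue0[0])

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    # Submit every task first so the workers can run concurrently ...
    futures = {value0: executor.submit(returnSelectedArrNb, arr, value0)
               for value0 in arr_uniqueValue0}
    # ... and only then block on the results.
    dic_value0pool4 = {value0: fut.result() for value0, fut in futures.items()}
```

Whether this beats the single-threaded loop depends on the machine. Boolean-mask selection is largely memory-bound, so the gain from extra threads may be modest.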


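A different approach, not taken from the question and offered only as a sketch, avoids scanning the whole array once per unique value: sort by `value0` once, then split the sorted array at the group boundaries. It uses only NumPy (`np.argsort`, `np.unique` with `return_index=True`, and `np.split`). The reduced array size is an assumption for illustration.

```python
import numpy as np

dtype = np.dtype([('order', '<u4'), ('time', '<u8'), ('value0', '<f8'),
                  ('value1', '<u2'), ('value2', '<i8')])

arr = np.zeros(100_000, dtype=dtype)
arr['order'] = np.arange(arr.size)
arr['value0'] = np.random.randint(0, 50, size=arr.size)

# A stable sort keeps rows with equal value0 in their original relative order.
sort_idx = np.argsort(arr['value0'], kind='stable')
arr_sorted = arr[sort_idx]

# On sorted data, the first index of each unique value marks where its group starts.
keys, group_starts = np.unique(arr_sorted['value0'], return_index=True)

# Split at every start except the first (index 0) to get one chunk per key.
dic_value0 = dict(zip(keys, np.split(arr_sorted, group_starts[1:])))
```

Each dictionary value here is a view into `arr_sorted`, not an independent copy. Call `.copy()` on a chunk if it needs to be modified separately.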