首页 > 解决方案 > 如何使用地图和列表理解以外的功能加速计算 NumPy 数组中的元素?

问题描述

我有一个包含 3 列和约 20k 行的数据文件(temp.dat)。它看起来像这样:

0 1 100.00
0 2 100.00
0 3 100.00
...
1 10 100.00
1 11 100.00
1 12 100.00
1 13 100.00
1 14 100.00
1 15 100.00
1 16 100.00
1 17 100.00
...
0 10 100.00
0 11 100.00
0 12 100.00
...

我想计算代码中满足以下条件的行数。我尝试了地图和列表理解,但两者似乎都慢得令人难以置信。列表理解快了大约一分钟。

data = np.genfromtxt('temp.dat')
base1, base2, pct = data[:,0], data[:,1], data[:,2]
expected_count = 10000

BASE_NAME = []
for x in range(0,36):
    count1 = sum(map(lambda base1 : base1 == x, base1)) 
    count2 = sum(map(lambda base2 : base2 == x, base2))
    total_count = count1 + count2
    if total_count == expected_count:
        base_num = x
        BASE_NAME.append(base_num)

total_base_name = len(BASE_NAME)
print (total_base_name)

对于列表理解,语法变为:

count1 = sum([base1 == x for base1 in base1])
count2 = sum([base2 == x for base2 in base2])

标签: python-3.x

解决方案


(已编辑:我有点忽略了您使用的是 NumPy 数组)

取代:

sum(map(lambda base1 : base1 == x, base1)) 

或者:

count1 = sum([base1 == x for base1 in base1])

最好的方法取决于您的输入是一个list数组还是一个 NumPy 数组。

  • 如果你有一个list,你可以使用以下list.count()方法:
base1.count(x)
  • 如果你有一个 NumPy 数组,看起来就是这种情况,你可以使用np.count_nonzero()NumPy 数组:
import numpy as np


np.count_nonzero(base1 == x)

但是,这将创建一个可能很大的临时对象。这可以通过创建自己的函数并使用Cython(未显示)或更好的Numba加速它来解决,如下所示:

import numba as nb


@nb.jit
def nb_count_equal(arr, value):
    result = 0
    for x in arr:
        if x == value:
            result += 1
    return result

这也将比np.count_nonzero()在这种情况下更快。

在玩具数据上测试其中一些方法表明它们给出了相同的结果:

np.random.seed(0)  # to ensure reproducible results

arr = np.random.randint(0, 20, 1000)
y = 10

print(sum(map(lambda x: x == y, arr)))
# 41
print(sum([x == y for x in arr]))
# 41
print(np.count_nonzero(arr == y))
# 41
print(nb_count_equal(arr, y))
# 41

时间安排如下:

arr = np.random.randint(0, 20, 1000000)
y = 10

%timeit sum(map(lambda x: x == y, arr))
# 1 loop, best of 3: 2.54 s per loop
%timeit sum([x == y for x in arr])
# 1 loop, best of 3: 2.43 s per loop
%timeit np.count_nonzero(arr == y)
# 1000 loops, best of 3: 574 µs per loop
%timeit nb_count_equal(arr, y)
# 1000 loops, best of 3: 224 µs per loop

请注意,之前删除方括号以避免创建临时列表的建议比仅使用生成器要慢,因为sum()实现方式,但它肯定具有避免创建不必要的临时列表的优势。


最后,如果您要多次进行此计数,则一次执行此操作可能更有益np.unique()


推荐阅读