首页 > 解决方案 > 提高 np.where() 循环的性能

问题描述

我正在尝试在不进行多处理的情况下加快以下脚本的代码(理想情况下 > 4x)。在未来的步骤中,我将实现多处理,但是即使我将其拆分为 40 个内核,当前的速度也太慢了。因此,我尝试先优化代码。

import numpy as np

def loop(x,y,q,z):
    matchlist = []
    for ind in range(len(x)):
        matchlist.append(find_match(x[ind],y[ind],q,z))
    return matchlist

def find_match(x,y,q,z):
    A = np.where(q == x)
    B = np.where(z == y)
    return np.intersect1d(A,B)


# N will finally scale up to 10^9
N = 1000
M = 300

X = np.random.randint(M, size=N)
Y = np.random.randint(M, size=N)

# Q and Z size is fixed at 120000
Q = np.random.randint(M, size=120000)
Z = np.random.randint(M, size=120000)

# convert int32 arrays to str64 arrays, to represent original data (which are strings and not numbers)
X = np.char.mod('%d', X)
Y = np.char.mod('%d', Y)
Q = np.char.mod('%d', Q)
Z = np.char.mod('%d', Z)

matchlist = loop(X,Y,Q,Z)

我有两个长度相同的数组(X 和 Y)。这些阵列的每一行对应一个 DNA 测序读数(基本上是字母“A”、“C”、“G”、“T”的字符串;细节与此处的示例代码无关)。

我还有两个长度相同的“参考数组”(Q 和 Z),我想找到 Q 中 X 的每个元素以及 Y 中的每个元素的出现(使用 np.where()) Z(基本上是 find_match() 函数)。之后我想知道为 X 和 Y 找到的索引之间是否存在重叠/相交。

示例输出(匹配列表;某些 X/Y 行在 Q/Y 中有匹配的索引,有些则没有,例如索引 11):

匹配列表的示例输出

到目前为止,该代码运行良好,但在N=10^9的最终数据集上执行需要很长时间(在此代码示例中,N=1000 以使测试更快)。在我的笔记本电脑上执行 1000 行 X/Y 需要大约 2.29 秒:

loop() 的 timeit 测试

每次find_match()执行大约需要 2.48 毫秒,大约是最终循环的 1/1000。

find_match() 的 timeit 测试

第一种方法是将 (x 与 y) 以及 (q 与 z) 结合起来,然后我只需要运行np.where()一次,但我还不能让它工作。

我尝试在 Pandas ( .loc()) 中循环和查找,但这比np.where().

这个问题与 philshem 最近提出的一个问题密切相关(将几个 NumPy “where”语句组合成一个以提高性能),但是,在这个问题上提供的解决方案不适用于我在这里的方法。

标签: pythonnumpyperformance

解决方案


Numpy 在这里用处不大,因为您需要的是查找锯齿状数组,并以字符串作为索引。

lookup = {}
for i, (q, z) in enumerate(zip(Q, Z)):
    lookup.setdefault((q, z), []).append(i)

matchlist = [lookup.get((x, y), []) for x, y in zip(X, Y)]

如果您不需要将输出作为锯齿状数组,但只需要一个表示存在的布尔值就可以了,并且可以将每个字符串预处理为一个数字,那么有一种更快的方法。

lookup = np.zeros((300, 300), dtype=bool)
lookup[Q, Z] = True
matchlist = lookup[X, Y]

您通常不希望使用此方法来替换以前的锯齿状情况,因为密集变体(例如 Daniel F 的解决方案)将导致内存效率低下,并且 numpy 不能很好地支持稀疏数组。但是,如果需要更高的速度,那么稀疏解决方案当然是可能的。


推荐阅读