首页 > 解决方案 > Numpy:如何有效地获取每行的前 N ​​个元素?

问题描述

我尝试获取每行的 (index, value) 元组的 topN 列表。

heapq在处理过程中只使用单核,然后我尝试使用多处理但时间消耗更长。

有没有更快的方法来获得结果?

谢谢

import heapq
import multiprocessing
import numpy
import time


class C1:

    def __init__(self):
        self.data = numpy.random.rand(100, 50000)
        self.top_n = 5000

    def run_normal(self):
        output = []
        for item_index in range(self.data.shape[0]):
            objs = heapq.nlargest(self.top_n, enumerate(self.data[item_index]), lambda x: x[1])
            output.append(objs)

    def run_mp(self):
        with multiprocessing.Pool() as pool:
            output = pool.map(self.sort_arr, self.data.tolist())

    def sort_arr(self, arr):
        return heapq.nlargest(self.top_n, enumerate(arr), lambda x: x[1])


if __name__ == '__main__':
    c1 = C1()

    start = time.time()
    c1.run_normal()
    print(time.time() - start)

    start = time.time()
    c1.run_mp()
    print(time.time() - start)

输出

3.2407033443450928 # for-loop time
12.387788534164429 # multiprocessing time

标签: pythonnumpysorting

解决方案


明确说明问题:

我们得到一个包含数据点的 M x N numpy 数组。我们想要获得一个 M xk,其中每一行包含原始数组中的前 k 个值,并与原始行中的值索引配对。

例如:对于 [[1, 2], [4, 3], [5, 6]] 和 k = 1 的输入,我们想输出 [[(0, 1)], [(1, 3 )]、[(0, 5)]]。

解决方案

最好和最快的解决方案是使用本机 numpy 功能。策略是首先获取每行的顶部索引,然后从这些索引中获取元素,然后将两者组合到我们的输出数组中。

data = np.random(100, 50000)  # large
k = 5

# Define the type of our output array elements: (int, float)
dt = np.dtype([('index', np.int32, 1), ('value', np.float64, 1)])

# Take the indices of the largest k elements from each row
top_k_inds = np.argsort(data)[:, -1:-k - 1:-1]

# Take the values at those indices
top_k = np.take_along_axis(data, top_k_inds, axis=-1)

# Stack the two together along a third axis (to get index-value pairs)
top_k_pairs = np.stack((top_k_inds, top_k), axis=2)

# Convert the type (otherwise we have the indices as floats)
top_k_pairs = top_k_pairs.astype(dt)

推荐阅读