首页 > 解决方案 > 如何通过矢量化实现将索引列表转换为 numpy 中的单元格列表(numpy 列表数组)?

问题描述

单元列表是一种数据结构,用于维护 ND 网格中的数据点列表。例如,以下二维索引列表:

ind = [(0, 1), (1, 0), (0, 1), (0, 0), (0, 0), (0, 0), (1, 1)]

转换为以下 2x2 单元格列表:

cell = [[[3, 4, 5], [0, 2]],
        [[1, ],     [6, ]]
       ]

使用 O(n) 算法:

# create an empty 2x2 cell list
cell = [[[] for _ in range(2)] for _ in range(2)]
for i in range(len(ind)):
    cell[ind[i][0], ind[i][1]].append(i)

numpy 中是否有一种矢量化方式可以将索引列表 ( ind) 转换为上述单元结构?

标签: numpy

解决方案


我不认为有一个好的纯numpy,但你可以使用pythran或者——如果你不想接触编译器——scipy.sparse参见。这个问答本质上是您问题的一维版本。

[stb_pthr.py] 从将数组排序到索引数组指定的 bin 中的最有效方法简化?

import numpy as np

#pythran export sort_to_bins(int[:], int)

def sort_to_bins(idx, mx=-1):
    if mx==-1:
        mx = idx.max() + 1
    cnts = np.zeros(mx + 1, int)
    for i in range(idx.size):
        cnts[idx[i] + 1] += 1
    for i in range(1, cnts.size):
        cnts[i] += cnts[i-1]
    res = np.empty_like(idx)
    for i in range(idx.size):
        res[cnts[idx[i]]] = i
        cnts[idx[i]] += 1
    return res, cnts[:-1]

编译:pythran stb_pthr.py

主脚本:

import numpy as np
try:
    from stb_pthr import sort_to_bins
    HAVE_PYTHRAN = True
except:
    HAVE_PYTHRAN = False
from scipy import sparse

def fallback(flat, maxind):
    res = sparse.csr_matrix((np.zeros_like(flat),flat,np.arange(len(flat)+1)),
                            (len(flat), maxind)).tocsc()
    return res.indices, res.indptr[1:-1]

if not HAVE_PYTHRAN:
    sort_to_bins = fallback

def to_cell(data, shape=None):
    data = np.asanyarray(data)
    if shape is None:
        *shape, = (1 + c.max() for c in data.T)
    flat = np.ravel_multi_index((*data.T,), shape)
    reord, bnds = sort_to_bins(flat, np.prod(shape))
    return np.frompyfunc(np.split(reord, bnds).__getitem__, 1, 1)(
        np.arange(np.prod(shape)).reshape(shape))

ind = [(0, 1), (1, 0), (0, 1), (0, 0), (0, 0), (0, 0), (1, 1)]

print(to_cell(ind))

from timeit import timeit

ind = np.transpose(np.unravel_index(np.random.randint(0, 100, (1_000_000)), (10, 10)))

if HAVE_PYTHRAN:
    print(timeit(lambda: to_cell(ind), number=10)*100)
    sort_to_bins = fallback # !!! MUST REMOVE THIS LINE AFTER TESTING
print(timeit(lambda: to_cell(ind), number=10)*100)

示例运行,输出是 OP 的玩具示例的答案,以及 1,000,000 点示例的解决方案的时间(以毫秒为单位pythranscipy

[[array([3, 4, 5]) array([0, 2])]
 [array([1]) array([6])]]
11.411489499732852
29.700406698975712

推荐阅读