首页 > 解决方案 > NumPy - 按频率对大型数组进行快速稳定的 arg 排序

问题描述

我有任何可比较的大型 1D NumPy数组,它的某些元素可能会重复。adtype

如何按降序/升序的值频率找到ix将稳定排序(此处描述的某种意义上的稳定性)的排序索引?a

我想找到最快和最简单的方法来做到这一点。也许有现有的标准 numpy 函数可以做到这一点。

这里还有另一个相关的问题,但它专门要求删除数组重复项,即只输出唯一的排序值,我需要原始数组的所有值,包括重复项。

我已经编写了我的第一个试验来完成这项任务,但它不是最快的(使用 Python 的循环)并且可能不是最短/最简单的可能形式。如果相等元素的重复率不高且数组很大,则此 python 循环可能非常昂贵。如果在 NumPy 中可用(例如 imaginary np.argsort_by_freq()),也可以使用简短的函数来完成这一切。

在线尝试!

import numpy as np
np.random.seed(1)
hi, n, desc = 7, 24, True
a = np.random.choice(np.arange(hi), (n,), p = (
    lambda p = np.random.random((hi,)): p / p.sum()
)())
us, cs = np.unique(a, return_counts = True)
af = np.zeros(n, dtype = np.int64)
for u, c in zip(us, cs):
    af[a == u] = c
if desc:
    ix = np.argsort(-af, kind = 'stable') # Descending sort
else:
    ix = np.argsort(af, kind = 'stable') # Ascending sort
print('rows: i_col(0) / original_a(1) / freqs(2) / sorted_a(3)')
print('    / sorted_freqs(4) / sorting_ix(5)')
print(np.stack((
    np.arange(n), a, af, a[ix], af[ix], ix,
), 0))

输出:

rows: i_col(0) / original_a(1) / freqs(2) / sorted_a(3)
    / sorted_freqs(4) / sorting_ix(5)
[[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
 [ 1  1  1  1  3  0  5  0  3  1  1  0  0  4  6  1  3  5  5  0  0  0  5  0]
 [ 7  7  7  7  3  8  4  8  3  7  7  8  8  1  1  7  3  4  4  8  8  8  4  8]
 [ 0  0  0  0  0  0  0  0  1  1  1  1  1  1  1  5  5  5  5  3  3  3  4  6]
 [ 8  8  8  8  8  8  8  8  7  7  7  7  7  7  7  4  4  4  4  3  3  3  1  1]
 [ 5  7 11 12 19 20 21 23  0  1  2  3  9 10 15  6 17 18 22  4  8 16 13 14]]

标签: pythonarraysnumpysortingfrequency

解决方案


我可能遗漏了一些东西,但似乎可以使用 aCounter然后根据元素值的计数对每个元素的索引进行排序,使用元素值,然后使用索引来打破关系。例如:

from collections import Counter

a = [ 1,  1,  1,  1,  3,  0,  5,  0,  3,  1,  1,  0,  0,  4,  6,  1,  3,  5,  5,  0,  0,  0,  5,  0]
counts = Counter(a)

t = [(counts[v], v, i) for i, v in enumerate(a)]
t.sort()
print([v[2] for v in t])
t.sort(reverse=True)
print([v[2] for v in t])

输出:

[13, 14, 4, 8, 16, 6, 17, 18, 22, 0, 1, 2, 3, 9, 10, 15, 5, 7, 11, 12, 19, 20, 21, 23]
[23, 21, 20, 19, 12, 11, 7, 5, 15, 10, 9, 3, 2, 1, 0, 22, 18, 17, 6, 16, 8, 4, 14, 13]

如果要保持具有相同计数的组的索引的升序,则可以使用 lambda 函数进行降序:

t.sort(key = lambda x:(-x[0],-x[1],x[2]))
print([v[2] for v in t])

输出:

[5, 7, 11, 12, 19, 20, 21, 23, 0, 1, 2, 3, 9, 10, 15, 6, 17, 18, 22, 4, 8, 16, 14, 13]

如果要按照它们最初出现在数组中的顺序保持元素的顺序(如果它们的计数相同),那么不要对值进行排序,而是对它们在数组中第一次出现的索引进行排序:

a = [ 1,  1,  1,  1,  3,  0,  5,  0,  3,  1,  1,  0,  0,  4,  6,  1,  3,  5,  5,  0,  0,  0,  5,  0]
counts = Counter(a)

idxs = {}
t = []
for i, v in enumerate(a):
    if not v in idxs:
        idxs[v] = i
    t.append((counts[v], idxs[v], i))

t.sort()
print([v[2] for v in t])
t.sort(key = lambda x:(-x[0],x[1],x[2]))
print([v[2] for v in t])

输出:

[13, 14, 4, 8, 16, 6, 17, 18, 22, 0, 1, 2, 3, 9, 10, 15, 5, 7, 11, 12, 19, 20, 21, 23]
[5, 7, 11, 12, 19, 20, 21, 23, 0, 1, 2, 3, 9, 10, 15, 6, 17, 18, 22, 4, 8, 16, 13, 14]

要根据计数排序,然后在数组中定位,您根本不需要值或第一个索引:

from collections import Counter

a = [ 1,  1,  1,  1,  3,  0,  5,  0,  3,  1,  1,  0,  0,  4,  6,  1,  3,  5,  5,  0,  0,  0,  5,  0]
counts = Counter(a)

t = [(counts[v], i) for i, v in enumerate(a)]
t.sort()
print([v[1] for v in t])
t.sort(key = lambda x:(-x[0],x[1]))
print([v[1] for v in t])

对于您的字符串数组,这将产生与示例数据的先前代码相同的输出:

a = ['g',  'g',  'c',  'f',  'd',  'd',  'g',  'a',  'a',  'a',  'f',  'f',  'f',
     'g',  'f',  'c',  'f',  'a',  'e',  'b',  'g',  'd',  'c',  'b',  'f' ]

这将产生输出:

[18, 19, 23, 2, 4, 5, 15, 21, 22, 7, 8, 9, 17, 0, 1, 6, 13, 20, 3, 10, 11, 12, 14, 16, 24]
[3, 10, 11, 12, 14, 16, 24, 0, 1, 6, 13, 20, 7, 8, 9, 17, 2, 4, 5, 15, 21, 22, 19, 23, 18]

推荐阅读