Find all subsets of a 1D array that have a specified number of unique elements (n) and a specified sum (s)

Problem description

Consider a 1D numpy array and two constants, as follows:

import numpy as np

arr = np.arange(60)

n = 5
s = 120

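That is, arr always has the form [0, 1, 2, 3, 4, ..., 58, 59].

Question: from the 1D array (arr), I need to find all subsets of exactly n unique elements whose sum equals s. The solution could start like this: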

          [[2, 10, 26, 35, 47], 
           [9, 14, 15, 40, 42],
           etc...

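I have shown the elements of each row in sorted order. That would be nice, but it is not required (i.e., I want combinations, not permutations).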
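To make the requirement concrete, here is a minimal sketch of a check that a candidate row satisfies it. The helper name `is_valid_subset` is made up for illustration and is not part of the original question.

```python
def is_valid_subset(row, n, s):
    """Return True if row has exactly n distinct elements that sum to s."""
    # len(set(row)) == n rejects rows that repeat an element.
    return len(row) == n and len(set(row)) == n and sum(row) == s


# The two example rows from the question, with n = 5 and s = 120.
print(is_valid_subset([2, 10, 26, 35, 47], n=5, s=120))
print(is_valid_subset([9, 14, 15, 40, 42], n=5, s=120))
```

Since order does not matter, [47, 35, 26, 10, 2] counts as the same subset as the first row.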

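Currently I handle this computation in a SQL variant, using a cross product of identical tables that each contain arr. This works, but it is very slow, especially as n approaches 12 or so.

Is there a way to do this efficiently and quickly in Python/NumPy?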

Tags: python, numpy

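Solution


A recursive solution follows.

Its running time is bounded by O(len(arr)^(max(n, len(arr) - n))).

I will come back to this...

Even so, the following solution will still be much faster than yours.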

import numpy as np


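def perms(arr: np.ndarray, n: int, s: int, used_inds: set):
    # n: elements still to pick; s: sum still to reach;
    # used_inds: indices of the elements picked so far.
    if s == 0 and n == 0:
        # Exactly n elements picked and the target sum reached: print them.
        print(arr[np.array(list(used_inds))])
        return

    if s == 0:  # sum reached too soon: elements still left to pick
        return

    if n == 0:  # all n elements picked, but the sum was not reached
        return

    if s < 0:  # overshot the target; assumes arr has no negative values
        return

    for i in range(len(arr)):
        if i in used_inds:
            continue
        # Copy the index set so sibling branches do not see this choice.
        new_used_inds = set(used_inds)
        new_used_inds.add(i)
        perms(arr=arr, n=n - 1, s=s - arr[i], used_inds=new_used_inds)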


def main():
    arr = np.arange(20)

    n = 3
    s = 15

    perms(arr=arr, n=n, s=s, used_inds=set())


if __name__ == "__main__":
    main()

Output:
[ 0  1 14]
[ 0  2 13]
[ 0  3 12]
[ 0 11  4]
[ 0 10  5]
[0 9 6]
[0 8 7]
[0 8 7]
[0 9 6]
[ 0 10  5]
[ 0 11  4]
[ 0  3 12]
[ 0  2 13]
[ 0  1 14]
[ 0  1 14]
[ 1  2 12]
[11  1  3]
[ 1 10  4]
[1 5 9]
[8 1 6]
[8 1 6]
[1 5 9]
[ 1 10  4]
[ 3  1 11]
[ 1  2 12]
[ 0  2 13]
[ 1  2 12]
[10  2  3]
[9 2 4]
[8 2 5]
[2 6 7]
[2 6 7]
[8 2 5]
[9 2 4]
[ 3  2 10]
[ 1  2 12]
[ 0  3 12]
[11  1  3]
[10  2  3]
[8 3 4]
[3 5 7]
[3 5 7]
[8 3 4]
[ 2 10  3]
[11  1  3]
[ 0 11  4]
[ 1 10  4]
[9 2 4]
[8 3 4]
[4 5 6]
[4 5 6]
[8 3 4]
[9 2 4]
[ 1 10  4]
[ 0 10  5]
[1 5 9]
[8 2 5]
[3 5 7]
[4 5 6]
[4 5 6]
[3 5 7]
[8 2 5]
[9 5 1]
[0 9 6]
[8 1 6]
[2 6 7]
[4 5 6]
[4 5 6]
[2 6 7]
[8 1 6]
[0 8 7]
[2 6 7]
[3 5 7]
[3 5 7]
[2 6 7]
[8 0 7]
[8 1 6]
[8 2 5]
[8 3 4]
[8 3 4]
[8 2 5]
[8 1 6]
[0 9 6]
[9 5 1]
[9 2 4]
[9 2 4]
[9 5 1]
[ 0 10  5]
[ 1 10  4]
[ 3 10  2]
[ 2 10  3]
[ 1 10  4]
[ 0 11  4]
[ 3  1 11]
[ 3  1 11]
[ 0  3 12]
[ 1  2 12]
[ 1  2 12]
[ 0  2 13]
[ 0  1 14]

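This is still slow, but it performs far fewer computations than your proposed solution, because it cuts off branches as soon as it is known that further computation on them is pointless.

Note that a set does not preserve insertion order, so the elements within each printed row appear in an arbitrary order, which makes the output harder to read. Each combination is also printed several times, once for every ordering in which the recursion reaches it.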
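The idea of cutting off branches can be pushed further. The following is a minimal sketch, not the code from this answer: it assumes arr holds distinct non-negative integers, only ever picks elements at increasing positions (so each combination is produced once, already sorted), and abandons a branch when the remaining sum can no longer be reached. The name `subsets_with_sum` and its inner helper `search` are invented for illustration.

```python
import numpy as np


def subsets_with_sum(arr, n, s):
    """Yield each sorted n-element combination of arr whose sum is s.

    Assumes arr holds distinct non-negative integers.
    """
    values = sorted(int(v) for v in arr)
    size = len(values)

    def search(start, remaining_n, remaining_s, chosen):
        if remaining_n == 0:
            if remaining_s == 0:
                yield tuple(chosen)
            return
        # Stop early enough to leave remaining_n - 1 elements after position i.
        for i in range(start, size - remaining_n + 1):
            # Smallest reachable sum: the next remaining_n values from i on.
            if sum(values[i:i + remaining_n]) > remaining_s:
                break  # values are sorted, so every later i is even larger
            # Largest reachable sum: values[i] plus the biggest remaining_n - 1 values.
            if values[i] + sum(values[size - (remaining_n - 1):]) < remaining_s:
                continue  # even the best completion falls short
            chosen.append(values[i])
            yield from search(i + 1, remaining_n - 1, remaining_s - values[i], chosen)
            chosen.pop()

    yield from search(0, n, s, [])


for combo in subsets_with_sum(np.arange(20), n=3, s=15):
    print(combo)
```

For arr = np.arange(20), n = 3 and s = 15 this yields 19 tuples, from (0, 1, 14) to (4, 5, 6). Because it is a generator, results can be consumed one at a time instead of being held in memory.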


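Using an OrderedDict in place of the set makes the code slightly less readable but the output more readable, because the order in which elements were picked is preserved: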

import numpy as np
from collections import OrderedDict


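def perms(arr: np.ndarray, n: int, s: int, used_inds: OrderedDict):
    # n: elements still to pick; s: sum still to reach;
    # used_inds: indices picked so far, kept in the order they were picked.
    if s == 0 and n == 0:
        # Exactly n elements picked and the target sum reached: print them.
        print(arr[np.array(list(used_inds))])
        return

    if s == 0:  # sum reached too soon: elements still left to pick
        return

    if n == 0:  # all n elements picked, but the sum was not reached
        return

    if s < 0:  # overshot the target; assumes arr has no negative values
        return

    for i in range(len(arr)):
        if i in used_inds:
            continue
        # Copy the mapping so sibling branches do not see this choice.
        new_used_inds = OrderedDict(used_inds)
        new_used_inds[i] = None
        perms(arr=arr, n=n - 1, s=s - arr[i], used_inds=new_used_inds)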


def main():
    arr = np.arange(20)

    n = 3
    s = 15

    perms(arr=arr, n=n, s=s, used_inds=OrderedDict())


if __name__ == "__main__":
    main()

Output:
[ 0  1 14]
[ 0  2 13]
[ 0  3 12]
[ 0  4 11]
[ 0  5 10]
[0 6 9]
[0 7 8]
[0 8 7]
[0 9 6]
[ 0 10  5]
[ 0 11  4]
[ 0 12  3]
[ 0 13  2]
[ 0 14  1]
[ 1  0 14]
[ 1  2 12]
[ 1  3 11]
[ 1  4 10]
[1 5 9]
[1 6 8]
[1 8 6]
[1 9 5]
[ 1 10  4]
[ 1 11  3]
[ 1 12  2]
[ 2  0 13]
[ 2  1 12]
[ 2  3 10]
[2 4 9]
[2 5 8]
[2 6 7]
[2 7 6]
[2 8 5]
[2 9 4]
[ 2 10  3]
[ 2 12  1]
[ 3  0 12]
[ 3  1 11]
[ 3  2 10]
[3 4 8]
[3 5 7]
[3 7 5]
[3 8 4]
[ 3 10  2]
[ 3 11  1]
[ 4  0 11]
[ 4  1 10]
[4 2 9]
[4 3 8]
[4 5 6]
[4 6 5]
[4 8 3]
[4 9 2]
[ 4 10  1]
[ 5  0 10]
[5 1 9]
[5 2 8]
[5 3 7]
[5 4 6]
[5 6 4]
[5 7 3]
[5 8 2]
[5 9 1]
[6 0 9]
[6 1 8]
[6 2 7]
[6 4 5]
[6 5 4]
[6 7 2]
[6 8 1]
[7 0 8]
[7 2 6]
[7 3 5]
[7 5 3]
[7 6 2]
[8 0 7]
[8 1 6]
[8 2 5]
[8 3 4]
[8 4 3]
[8 5 2]
[8 6 1]
[9 0 6]
[9 1 5]
[9 2 4]
[9 4 2]
[9 5 1]
[10  0  5]
[10  1  4]
[10  2  3]
[10  3  2]
[10  4  1]
[11  0  4]
[11  1  3]
[11  3  1]
[12  0  3]
[12  1  2]
[12  2  1]
[13  0  2]
[14  0  1]

Your solution is equivalent to:

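import numpy as np
from itertools import combinations


def combs(arr: np.ndarray, n: int, s: int):
    # Brute force: visit every n-element combination and test its sum.
    for comb in combinations(arr, n):
        if sum(comb) == s:
            # Convert NumPy integers to plain ints for a clean printout.
            print(tuple(int(x) for x in comb))


combs(arr=np.arange(20), n=3, s=15)

Output: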
(0, 1, 14)
(0, 2, 13)
(0, 3, 12)
(0, 4, 11)
(0, 5, 10)
(0, 6, 9)
(0, 7, 8)
(1, 2, 12)
(1, 3, 11)
(1, 4, 10)
(1, 5, 9)
(1, 6, 8)
(2, 3, 10)
(2, 4, 9)
(2, 5, 8)
(2, 6, 7)
(3, 4, 8)
(3, 5, 7)
(4, 5, 6)

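It never cuts the computation short; it simply iterates over everything.
On the other hand, it at least does not allocate extra memory.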
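To get a feel for what "iterates over everything" costs, the number of combinations a brute-force scan has to visit is the binomial coefficient C(len(arr), n). A small sketch, with the helper name `brute_force_candidates` chosen only for illustration:

```python
import math


def brute_force_candidates(arr_len: int, n: int) -> int:
    """Number of n-element combinations a brute-force scan must visit."""
    return math.comb(arr_len, n)


# Compare the small test case, the question's case, and n = 12.
for arr_len, n in ((20, 3), (60, 5), (60, 12)):
    print(arr_len, n, brute_force_candidates(arr_len, n))
```

For the question's arr of 60 elements, n = 5 already means about 5.5 million combinations, and the count keeps growing quickly as n moves toward 12, which fits the slowdown described in the question.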

