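python - Find all subsets of a 1D array with a specified number of unique elements (n) and a specified sum (s)
Problem description
Consider a 1D numpy array and two constants, as follows: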
import numpy as np
arr = np.arange(60)
n = 5
s = 120
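For the sake of example, arr will always have the form [0, 1, 2, 3, 4, ..., 59, 60] (as written, np.arange(60) stops at 59).
Question: from the 1D array (arr), I need to find all subsets of exactly n UNIQUE elements that have the specified sum s. The solutions might start like this: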
[[2, 10, 26, 35, 47],
[9, 14, 15, 40, 42],
etc...
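I have shown the elements of each row in sorted order. That would be nice, but it is not required. (That is: combinations, not permutations.)
At present I handle this computation in a SQL variant, using a cross product of copies of the same table, each containing arr. This works, but it is very slow, especially once n reaches about 12.
Is there a way to do this efficiently and quickly in Python/Numpy?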
Solution
A recursive solution is below.
In the worst case it runs in O(len(arr)^(max(n, len(arr)-n))).
However, the solution below should still be much faster than yours.
import numpy as np


def perms(arr: np.ndarray, n: int, s: int, used_inds: set):
    if s == 0 and n == 0:
        # exactly n elements chosen and they add up to the target
        print(arr[np.array(list(used_inds))])
        return
    if s == 0:  # sum reached, but not enough elements in the group
        return
    if n == 0:  # n elements chosen, but the sum does not match [assumes non-negative values]
        return
    for i in range(len(arr)):
        if i in used_inds:
            continue
        new_used_inds = set(used_inds)
        new_used_inds.add(i)
        perms(arr=arr, n=n - 1, s=s - arr[i], used_inds=new_used_inds)


def main():
    arr = np.arange(20)
    n = 3
    s = 15
    perms(arr=arr, n=n, s=s, used_inds=set())


if __name__ == "__main__":
    main()
[ 0 1 14] [ 0 2 13] [ 0 3 12] [ 0 11 4] [ 0 10 5] [0 9 6] [0 8 7] [0 8 7] [0 9 6] [ 0 10 5] [ 0 11 4] [ 0 3 12] [ 0 2 13] [ 0 1 14] [ 0 1 14] [ 1 2 12] [11 1 3] [ 1 10 4] [1 5 9] [8 1 6] [8 1 6] [1 5 9] [ 1 10 4] [ 3 1 11] [ 1 2 12] [ 0 2 13] [ 1 2 12] [10 2 3] [9 2 4] [8 2 5] [2 6 7] [2 6 7] [8 2 5] [9 2 4] [ 3 2 10] [ 1 2 12] [ 0 3 12] [11 1 3] [10 2 3] [8 3 4] [3 5 7] [3 5 7] [8 3 4] [ 2 10 3] [11 1 3] [ 0 11 4] [ 1 10 4] [9 2 4] [8 3 4] [4 5 6] [4 5 6] [8 3 4] [9 2 4] [ 1 10 4] [ 0 10 5] [1 5 9] [8 2 5] [3 5 7] [4 5 6] [4 5 6] [3 5 7] [8 2 5] [9 5 1] [0 9 6] [8 1 6] [2 6 7] [4 5 6] [4 5 6] [2 6 7] [8 1 6] [0 8 7] [2 6 7] [3 5 7] [3 5 7] [2 6 7] [8 0 7] [8 1 6] [8 2 5] [8 3 4] [8 3 4] [8 2 5] [8 1 6] [0 9 6] [9 5 1] [9 2 4] [9 2 4] [9 5 1] [ 0 10 5] [ 1 10 4] [ 3 10 2] [ 2 10 3] [ 1 10 4] [ 0 11 4] [ 3 1 11] [ 3 1 11] [ 0 3 12] [ 1 2 12] [ 1 2 12] [ 0 2 13] [ 0 1 14]
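This is still slow, but it performs far fewer computations than your proposed solution, because it cuts off a branch as soon as it is known that further computation there is pointless.
Note that because used_inds is a set, the elements of each result may print in an arbitrary order, which makes the output harder to read.
Slightly less readable code, with more readable output: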
import numpy as np
from collections import OrderedDict


def perms(arr: np.ndarray, n: int, s: int, used_inds: OrderedDict):
    if s == 0 and n == 0:
        # exactly n elements chosen and they add up to the target;
        # the OrderedDict keys keep the order in which indices were picked
        print(arr[np.array(list(used_inds))])
        return
    if s == 0:  # sum reached, but not enough elements in the group
        return
    if n == 0:  # n elements chosen, but the sum does not match [assumes non-negative values]
        return
    for i in range(len(arr)):
        if i in used_inds:
            continue
        new_used_inds = OrderedDict(used_inds)
        new_used_inds[i] = None
        perms(arr=arr, n=n - 1, s=s - arr[i], used_inds=new_used_inds)


def main():
    arr = np.arange(20)
    n = 3
    s = 15
    perms(arr=arr, n=n, s=s, used_inds=OrderedDict())


if __name__ == "__main__":
    main()
[ 0 1 14] [ 0 2 13] [ 0 3 12] [ 0 4 11] [ 0 5 10] [0 6 9] [0 7 8] [0 8 7] [0 9 6] [ 0 10 5] [ 0 11 4] [ 0 12 3] [ 0 13 2] [ 0 14 1] [ 1 0 14] [ 1 2 12] [ 1 3 11] [ 1 4 10] [1 5 9] [1 6 8] [1 8 6] [1 9 5] [ 1 10 4] [ 1 11 3] [ 1 12 2] [ 2 0 13] [ 2 1 12] [ 2 3 10] [2 4 9] [2 5 8] [2 6 7] [2 7 6] [2 8 5] [2 9 4] [ 2 10 3] [ 2 12 1] [ 3 0 12] [ 3 1 11] [ 3 2 10] [3 4 8] [3 5 7] [3 7 5] [3 8 4] [ 3 10 2] [ 3 11 1] [ 4 0 11] [ 4 1 10] [4 2 9] [4 3 8] [4 5 6] [4 6 5] [4 8 3] [4 9 2] [ 4 10 1] [ 5 0 10] [5 1 9] [5 2 8] [5 3 7] [5 4 6] [5 6 4] [5 7 3] [5 8 2] [5 9 1] [6 0 9] [6 1 8] [6 2 7] [6 4 5] [6 5 4] [6 7 2] [6 8 1] [7 0 8] [7 2 6] [7 3 5] [7 5 3] [7 6 2] [8 0 7] [8 1 6] [8 2 5] [8 3 4] [8 4 3] [8 5 2] [8 6 1] [9 0 6] [9 1 5] [9 2 4] [9 4 2] [9 5 1] [10 0 5] [10 1 4] [10 2 3] [10 3 2] [10 4 1] [11 0 4] [11 1 3] [11 3 1] [12 0 3] [12 1 2] [12 2 1] [13 0 2] [14 0 1]
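Both recursive versions above report the same subset once per ordering, while the question asks for combinations. The following is a minimal sketch of one way to keep the pruning idea but emit each subset only once; it is not the code from the answer. It assumes the values are integers sorted in ascending order (which holds for np.arange), and the names subsets_with_sum and search are made up for illustration.

```python
from typing import Iterator, List, Sequence, Tuple


def subsets_with_sum(values: Sequence[int], n: int, s: int) -> Iterator[Tuple[int, ...]]:
    """Yield every n-element combination of `values` whose sum is `s`.

    Assumption: `values` is sorted in ascending order.
    """
    vals = [int(v) for v in values]  # plain ints; a NumPy array is accepted too
    chosen: List[int] = []

    def search(start: int, remaining: int, target: int) -> Iterator[Tuple[int, ...]]:
        if remaining == 0:
            if target == 0:
                yield tuple(chosen)
            return
        # Last index that still leaves enough elements to complete the subset.
        last = len(vals) - remaining
        for i in range(start, last + 1):
            # Smallest sum reachable from i: the next `remaining` elements.
            if sum(vals[i:i + remaining]) > target:
                break  # a larger i only raises this minimum
            # Largest sum reachable from i: vals[i] plus the top remaining-1 elements.
            if vals[i] + sum(vals[len(vals) - remaining + 1:]) < target:
                continue  # even the largest tail falls short; try a bigger vals[i]
            chosen.append(vals[i])
            yield from search(i + 1, remaining - 1, target - vals[i])
            chosen.pop()

    yield from search(0, n, s)


# Same small case as in the answer: 3 elements of 0..19 that add up to 15.
for combo in subsets_with_sum(range(20), 3, 15):
    print(combo)
```

Because indices are only ever chosen in increasing order, each subset appears once with its elements sorted. The two bound checks recompute slice sums on every step for clarity; precomputed prefix sums would avoid that cost.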
Your solution is equivalent to:
import numpy as np
from itertools import combinations


def combs(arr: np.ndarray, n: int, s: int):
    # test every n-element combination; nothing is pruned
    for comb in combinations(arr, n):
        total = sum(comb)
        if total == s:
            print(comb)
(0, 1, 14) (0, 2, 13) (0, 3, 12) (0, 4, 11) (0, 5, 10) (0, 6, 9) (0, 7, 8) (1, 2, 12) (1, 3, 11) (1, 4, 10) (1, 5, 9) (1, 6, 8) (2, 3, 10) (2, 4, 9) (2, 5, 8) (2, 6, 7) (3, 4, 8) (3, 5, 7) (4, 5, 6)
It does not cut off the computation early; it simply iterates over everything.
At least it does not allocate any extra memory.
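To get a feel for how much "everything" is, the sizes of the two search spaces can be computed directly. This is a small illustrative sketch, not part of the original answer; the helper name search_space is hypothetical. It uses math.comb for the number of n-element subsets that the combinations approach tests, and math.perm for the number of ordered selections, which is an upper bound on the leaves the recursive versions can reach.

```python
import math
from typing import Tuple


def search_space(num_values: int, n: int) -> Tuple[int, int]:
    """Return (unordered subsets, ordered selections) of n items out of num_values."""
    return math.comb(num_values, n), math.perm(num_values, n)


# The small example from the answer, the question's case, and the n = 12 case.
for num_values, n in [(20, 3), (60, 5), (60, 12)]:
    subsets, ordered = search_space(num_values, n)
    print(f"{n} of {num_values}: {subsets:,} subsets, {ordered:,} ordered selections")
```

For 3 of 20 this gives 1,140 subsets and 6,840 ordered selections; for 5 of 60 it gives 5,461,512 subsets. The counts grow quickly with n, which is consistent with the slowdown reported in the question as n approaches 12.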