首页 > 解决方案 > 树查找积大于阈值的列表的所有笛卡尔积

问题描述

让我们以这样的列表为例:

li=[[0.99, 0.002],
 [0.98, 0.0008, 0.0007],
 [0.97, 0.009, 0.001],
 [0.86, 0.001]]

请注意,每个子列表中的元素按降序排序,并且它们的总和始终小于或等于 1。此外,子列表本身按其第一个元素的降序排序。

我有兴趣找到组合,从每个子列表中获取一个元素,使得组合元素的乘积高于某个阈值,例如 1e-5。我发现这样做的一种方法是使用 itertools.product。

a = list(itertools.product(*li))
[item for item in a if np.prod(item)>1e-5]

但是,这个过程对我来说是不可行的,因为我的实际列表有太多的子列表,所以要检查的可能组合的数量太大。

我必须做相反的事情,即只找到满足给定条件的组合,而不是首先找到所有组合并检查阈值条件。例如:由于 0.002*0.0008*0.009 已经小于 1e-5,我可以忽略所有其他以 (0.002, 0.0008,0.009,...) 开头的组合。

我找不到一种简单的方法来实现这一点。我想到的是一个树数据结构,我在其中构建一棵树,这样每个节点都会跟踪产品,并且一旦节点值低于 1e-5,我就会停止在该节点上进一步构建树,并且在它右边的节点上(因为右边的节点将小于当​​前节点)。

一个简单的树骨架开始:

class Tree(object):
    def __init__(self, node=None):
        self.node = node
        self.children = []

    def add_child(self, child):
        self.children.append(child)

一旦构建了树,我将提取达到的组合depth = len(li)

在此处输入图像描述

任何帮助建立这样的树或解决问题的任何其他想法都将受到高度赞赏。谢谢!

标签: pythondata-structurestree

解决方案


因为您的项目及其子项目都已排序并且介于 0 和 1 之间,所以 itertools.product 的输出不会增加。数学。正如您指出的那样,这并不奇怪,但是您如何利用这一点...

我认为您想要的是 itertools.product 的副本,其中包含在产品低于阈值时立即修剪分支的快捷方式。这将允许您有效地迭代所有可能的匹配项,而不会浪费时间重新检查您已经知道不能满足阈值的产品。

我在这里找到了 itertools.product 的迭代器实现:how code a function similar to itertools.product in python 2.5(我使用的是 python 3,它似乎工作正常。)

所以我只是复制了它,并在循环内插入了阈值检查

# cutoff function
from functools import reduce
from operator import mul

threshold = 1e-5

def cutoff(args):
    if args:
        return reduce(mul, args) < threshold
    return False

# alternative implementation of itertools.product with cutoff
def product(*args, **kwds):
    def cycle(values, uplevel):
        for prefix in uplevel:       # cycle through all upper levels
            if cutoff(prefix):
                break
            for current in values:   # restart iteration of current level
                result = prefix + (current,)
                if cutoff(result):
                    break
                yield result

    stack = iter(((),))             
    for level in tuple(map(tuple, args)) * kwds.get('repeat', 1):
        stack = cycle(level, stack)  # build stack of iterators
    return stack

# your code here
li=[[0.99, 0.002],
    [0.98, 0.0008, 0.0007],
    [0.97, 0.009, 0.001],
    [0.86, 0.001]]

for a in product(*li):
    p = reduce(mul, a)
    print (p, a)

如果我省略截止值,我会得到相同的结果,然后稍后检查 p > 阈值。

(0.99, 0.98, 0.97, 0.86) 0.8093408399999998
(0.99, 0.98, 0.97, 0.001) 0.0009410939999999998
(0.99, 0.98, 0.009, 0.86) 0.007509348
(0.99, 0.98, 0.001, 0.86) 0.0008343719999999999
(0.99, 0.0008, 0.97, 0.86) 0.0006606864
(0.99, 0.0007, 0.97, 0.86) 0.0005781006
(0.002, 0.98, 0.97, 0.86) 0.0016350319999999998
(0.002, 0.98, 0.009, 0.86) 1.591703999999e-999


推荐阅读