What is an efficient way to find the minimum sum of multiple dictionary values, given keys of mutually exclusive integers?

Question

I have a dictionary, the keys of which consist of all len(4) combinations of the integers 0 to n, with all values being floats (representing a cost that was computed by another function).

e.g.:

cost_dict = {(0,1,2,3): 6.23, (0,1,2,4): 7.89,
             ...
             (14,15,16,17): 2.57}

I would like to efficiently find m mutually exclusive keys (that is, where the keys do not share any of their integers) whose values sum to the lowest number (thus, finding the lowest overall cost). That is, I don't just want the m minimum values of the dictionary, I want m mutually exclusive values that sum to the lowest value. (Or failing the absolute minimum, I wouldn't mind something efficient that comes pretty close).

So in the above example, for m = 3, maybe:

cost_dict[(0,3,5,11)]
>1.1 
cost_dict[(2,6,7,13)]
>0.24
cost_dict[(4,10,14,15)]
>3.91

... could be the keys whose values sum to the lowest possible value, of all mutually exclusively keys in this dictionary.

It may be possible that the smallest three values in the dict were something like:

cost_dict[(0,3,7,13)]
>0.5
cost_dict[(2,6,7,13)]
>0.24
cost_dict[(4,6,14,15)]
>0.8

But given the integers in these keys are not mutually exclusive, this would not be correct.
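As a sanity check, mutual exclusivity of a candidate set of keys can be tested with pairwise set disjointness; a small sketch (the helper name is just for illustration):

```python
from itertools import combinations

def mutually_exclusive(keys):
    # True when no integer appears in more than one key.
    return all(set(a).isdisjoint(b) for a, b in combinations(keys, 2))

print(mutually_exclusive([(0, 3, 5, 11), (2, 6, 7, 13), (4, 10, 14, 15)]))  # True
print(mutually_exclusive([(0, 3, 7, 13), (2, 6, 7, 13), (4, 6, 14, 15)]))   # False
```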


Is it possible to do better than O(n**m) time? That is, I could sum every item against every other item whose key is disjoint with the first (this would need the keys to be frozensets instead of tuples) for m levels. This is rather slow given the dictionary's length can be up to 10,000.

Something that seems to have helped me with an earlier version of this problem is creating a list of all possible combinations of keys, which is time-intensive, but potentially more efficient given that I will need to be finding the minimum cost numerous times.
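For reference, the brute force described above can be sketched directly; it is exponential in m, so it is only viable as a correctness check on tiny inputs:

```python
from itertools import combinations

def brute_force(cost_dict, m):
    # Try every m-subset of keys; keep the mutually exclusive one with the
    # smallest total cost.
    best_keys, best_val = None, float("inf")
    for keys in combinations(cost_dict, m):
        if all(set(a).isdisjoint(b) for a, b in combinations(keys, 2)):
            val = sum(cost_dict[k] for k in keys)
            if val < best_val:
                best_keys, best_val = keys, val
    return best_keys, best_val
```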

Tags: python, performance, combinations

Solution


I tried solving this problem in three different ways: an optimized brute force, a dynamic programming approach, and a greedy algorithm. The first two cannot handle inputs with n > 17, but they produce optimal solutions, so I could use them to verify the average performance of the greedy approach. I'll start with the dynamic programming approach, then describe the greedy one.

Dynamic programming

First, note that if (1, 2, 3, 4) and (5, 6, 7, 8) sum to less than (3, 4, 5, 6) and (1, 2, 7, 8), then your optimal solution can never contain both (3, 4, 5, 6) and (1, 2, 7, 8), because you could swap them for the former pair and get a smaller sum. Extending this logic, for any eight integers x0, x1, x2, x3, x4, x5, x6, x7 there is a single best split into two 4-tuples (a, b, c, d) and (e, f, g, h) that minimizes the sum, so every other split of those eight integers can be discarded.

Using this knowledge, we can map every 8-integer set drawn from [0, n) to its minimal sum by brute-forcing all ways to split it into two 4-tuples x0, x1, x2, x3 and x4, x5, x6, x7. We can then use those mappings to repeat the process, pairing the results for x0, ..., x7 with further 4-tuples to get the minima for x0, ..., x11, and so on. We repeat until we have the minimal sums for all sets x0, x1 ... x_(4*m-1), and then iterate over those to find the overall minimum.

from itertools import combinations

def dp_solve(const_dict, n, m):

    # Maps each combined key back to the tuple of original 4-tuples it is built from.
    lookup = {comb: (comb,) for comb in const_dict.keys()}

    keys = range(n)  # iterate in sorted order so combined keys stay sorted
    for size in range(8, 4 * m + 1, 4):
        for key_total in combinations(keys, size):
            # Pair every 4-subset with its complement: combinations() yields in
            # lexicographic order, so reversing the list of (size - 4)-combinations
            # lines each 4-subset up with exactly its complement.
            key1, key2 = min(
                zip(combinations(key_total, 4),
                    reversed(list(combinations(key_total, size - 4)))),
                key=lambda x: const_dict[x[0]] + const_dict[x[1]])

            k = tuple(sorted(key1 + key2))
            const_dict[k] = const_dict[key1] + const_dict[key2]
            lookup[k] = lookup[key1] + lookup[key2]

    key, val = min(((key, val) for key, val in const_dict.items()
                    if len(key) == 4 * m), key=lambda x: x[1])
    return lookup[key], val
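For testing, a cost dictionary of the shape the question describes can be built like this; the random values are a hypothetical stand-in for whatever cost function produced the asker's real floats:

```python
from itertools import combinations
import random

n, m = 12, 2
random.seed(0)

# One entry per 4-subset of range(n); keys are already sorted 4-tuples,
# ready to be passed to dp_solve(cost_dict, n, m) or greedy_solve(...).
cost_dict = {key: random.random() for key in combinations(range(n), 4)}
```

Note that dp_solve mutates this dictionary in place (it adds entries for the longer combined keys), so pass a copy if you need the original afterwards.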

Admittedly, this implementation is quite crude, since I kept micro-optimizing it in the hope of making it fast enough that I wouldn't have to switch to the greedy approach.

Greedy

This is probably the one you care about, since it handles fairly large inputs quickly and is quite accurate.

Start by building a list for partial sums, then iterate over the elements of the dictionary in order of increasing value. For each element, find all partial sums that don't create any conflicts with its key, "combine" each of them into a new partial sum, and append those to the list. In doing so, you build a list of the minimal partial sums that can be created from the k smallest values in the dictionary. To speed all this up, I use hash sets to quickly check which partial sums share integers with the current key.

In the "fast" greedy approach, you would abort the moment you find a partial sum whose key has length 4 * m (or, equivalently, consists of m 4-tuples). In my experience this usually produces fairly good results, but I wanted to add some logic to make it more accurate if needed. To do that, I added two factors:

  • extra_runs - determines how many extra iterations to spend looking for a better solution before breaking out
  • check_factor - specifies a multiple of the current search "depth" to scan forward for a single new integer that creates a better solution from the current state. This differs from the above in that it does not "preserve" each new integer it checks: it only does a quick sum to see whether a new minimum is created. This makes it significantly faster, at the cost that the other m - 1 4-tuples must already exist in one of the partial sums.

Combined, these checks seem to always find the true minimal sum, at the cost of roughly 5x the runtime (though still quite fast). To disable them, just pass 0 for both factors.

def greedy_solve(const_dict, n, m, extra_runs=10, check_factor=2):
    # Visit the 4-tuples from cheapest to most expensive.
    pairs = sorted(const_dict.items(), key=lambda x: x[1])

    # lookup[x] holds the indices of every partial sum whose keys contain the
    # integer x, so conflict checks reduce to fast set subtractions.
    lookup = [set() for _ in range(n)]
    nset = set()

    min_sums = []
    min_key, min_val = None, None
    for i, (pkey, pval) in enumerate(pairs):
        # Partial sums that share no integer with pkey.
        valid = set(nset)
        for x in pkey:
            valid -= lookup[x]
            lookup[x].add(len(min_sums))

        nset.add(len(min_sums))
        min_sums.append(((pkey,), pval))

        # Pre-register the indices of the combined sums appended just below.
        for x in pkey:
            lookup[x].update(range(len(min_sums), len(min_sums) + len(valid)))
        for idx in valid:
            comb, val = min_sums[idx]
            for key in comb:
                for x in key:
                    lookup[x].add(len(min_sums))
            nset.add(len(min_sums))
            min_sums.append((comb + (pkey,), val + pval))
            if len(comb) == m - 1 and (not min_key or min_val > val + pval):
                min_key, min_val = min_sums[-1]

        if min_key:
            if not extra_runs: break
            extra_runs -= 1

    # Forward scan: try each of the cheapest keys against the stored partial
    # sums, keeping only improvements (nothing new is stored).
    for pkey, pval in pairs[:int(check_factor * i)]:
        valid = set(nset)
        for x in pkey:
            valid -= lookup[x]

        for idx in valid:
            comb, val = min_sums[idx]
            if len(comb) < m - 1:
                nset.remove(idx)
            elif min_val > val + pval:
                min_key, min_val = comb + (pkey,), val + pval
    return min_key, min_val

I tested this for n < 36 and m < 9, and it seemed to run fairly quickly (a few seconds to finish in the worst case). I imagine it should handle your case of 12 <= n <= 24 quite fast.

