首页 > 解决方案 > 优化:具有最大值的受限整数分区

问题描述

使用下面的代码,我用每个分区中的数字来计算受限整数分区(每个数字在每个分区中只能出现一次)k,每个数字等于或大于1且不大于m。此代码会生成大量缓存值,以便快速耗尽内存。

例子:

sum := 15, k := 4, m:= 10预期的结果是6

具有以下受限整数分区:

1,2,3,9, 1,2,4,8, 1,2,5,7, 1,3,4,7, 1,3,5,7,2,3,4,6

public class Key{
  private final int sum;
  private final short k1;
  private final short start;
  private final short end;

  public Key(int sum, short k1, short start, short end){
    this.sum = sum;
    this.k1 = k1;
    this.start = start;
    this.end = end;
  }
  // + hashcode and equals
}

public BigInteger calcRestrictedIntegerPartitions(int sum,short k,short m){
  return calcRestrictedIntegerPartitionsHelper(sum,(short)0,k,(short)1,m,new HashMap<>());
}

private BigInteger calcRestrictedIntegerPartitionsHelper(int sum, short k1, short k, short start, short end, Map<Key,BigInteger> cache){
  if(sum < 0){
    return BigInteger.ZERO;
  }
  if(k1 == k){
    if(sum ==0){
      return BigInteger.ONE;
    }
    return BigInteger.ZERO;
  }
  if(end*(k-k1) < sum){
    return BigInteger.ZERO;
  }

  final Key key = new Key(sum,(short)(k-k1),start,end);

  BigInteger fetched = cache.get(key);

  if(fetched == null){
    BigInteger tmp = BigInteger.ZERO;

    for(short i=start; i <= end;i++){
      tmp = tmp.add(calcRestrictedIntegerPartitionsHelper(sum-i,(short)(k1+1),k,(short)(i+1),end,cache));
    }

    cache.put(key, tmp);
    return tmp;
  }

  return fetched;
}

是否有避免/减少缓存的公式?或者我如何计算受限整数部分k and m

标签: javamathcombinationsmathematical-optimizationinteger-partition

解决方案


您的密钥包含 4 个部分,因此哈希空间可能会达到这些部分的最大值乘积的值。可以使用反向循环和零值作为自然限制将密钥减少到 3 个部分。

lru_cachePython 示例使用哈希表大小 =的内置功能N*K*M

@functools.lru_cache(250000)
def diff_partition(N, K, M):
    '''Counts integer partitions of N with K distint parts <= M'''
    if K == 0:
        if N == 0:
            return 1
        return 0
    res = 0
    for i in range(min(N, M), -1, -1):
        res += diff_partition(N - i, K - 1, i - 1)
    return res

def diffparts(Sum, K, M):   #diminish problem size allowing zero part
    return diff_partition(Sum - K, K, M-1)

print(diffparts(500, 25, 200))

>>>147151784574

推荐阅读