首页 > 解决方案 > 计算求解方程的可能方法的数量

问题描述

这个问题是在 HackerEarth 的一个挑战中提出的:

马克正在解决一个有趣的问题。他需要找出许多不同的方式,例如
(i + 2*j+ k) % (x + y + 2*z) = 0,其中1 <= i,j,k,x,y,z <= N

帮他找。

约束:

1<= T<= 10

1<=N<= 1000

输入格式:

第一行包含 T,测试用例的数量。每个测试用例都包含一个整数,N 在单独的行中。

输出格式:

对于每个测试用例,在单独的行中输出不同方式的数量。

样本输入
2

1

2

样本输出

1

15

解释

在第一种情况下,唯一可能的方法是 i = j = k = x =y = z = 1

我没有任何方法可以解决这个问题,我已经尝试过一个,我知道它甚至不接近这个问题。

import random

def CountWays (N):

    # Write your code here
    i = random.uniform(1,N)
    j = random.uniform(1,N)
    k = random.uniform(1,N)
    x = random.uniform(1,N)
    y = random.uniform(1,N)
    z = random.uniform(1,N)
    d = 0
    for i in range(N):
        if (i+2*j+k)%(x+y+2*z)==0:
            d += 1
    return d

T = int(input())

for _ in range(T):

    N = int(input())

    out_ = CountWays(N)
    print (out_)

我的输出

0

0

相反,它应该给出输出

1

15

标签: python-3.xalgorithm

解决方案


分子 ( num) 的取值范围为 4 到 4N。分母 ( dom) 的取值范围为 4 到num。您可以将您的问题拆分为两个较小的问题:1)分子的给定值可以被多少分母的值整除?2) 给定的分母和分子有多少种构造方式?

要回答 1),我们可以简单地遍历分子的所有可能值,然后遍历分母 where 的所有值numerator % denominator == 0。回答2)我们可以找到满足等式和约束的分子和分母的所有分区。构造给定分子和分母的方法的数量将是每个分区数的乘积。

import itertools

def divisible_numbers(n):
    """
    Get all numbers with which n is divisible.
    """
    for i in range(1,n+1):
        if n % i == 0:
            yield i
        if i >= n:
            break

def get_partitions(n):
    """
    Generate ALL ways n can be partitioned into 3 integers.
    Modified from http://code.activestate.com/recipes/218332-generator-for-integer-partitions/#c9
    """

    a = [1]*n
    y = -1
    v = n
    while v > 0:
        v -= 1
        x = a[v] + 1
        while y >= 2 * x:
            a[v] = x
            y -= x
            v += 1
        w = v + 1
        while x <= y:
            a[v] = x
            a[w] = y
            if w == 2:
                yield a[:w + 1]
            x += 1
            y -= 1
        a[v] = x + y
        y = a[v] - 1
        if w == 3:
            yield a[:w]

def get_number_of_valid_partitions(num, N):
    """
    Get number of valid partitions of num, given that
    num = i + j + 2k, and that 1<=i,j,k<=N
    """
    n = 0
    for partition in get_partitions(num):
        # This can be done a bit more cleverly, but makes
        # the code extremely complicated to read, so
        # instead we just brute force the 6 combinations,
        # ignoring non-unique permutations using a set
        for i,j,k in set(itertools.permutations(partition)):
            if i <= N and j <= N and k <= 2*N and k % 2 == 0:
                n += 1
    return n

def get_number_of_combinations(N):
    """
    Get number of ways the equality can be solved under the given constraints
    """
    out = 0
    # Create a dictionary of number of valid partitions
    # for all numerator values we can encounter
    n_valid_partitions = {i: get_number_of_valid_partitions(i, N) for i in range(1,4*N+1)}

    for numerator in range(4,4*N+1):
        numerator_permutations = n_valid_partitions[numerator]
        for denominator in divisible_numbers(numerator):
            denominator_permutations = n_valid_partitions[denominator]
            if denominator < 4:
                continue
            out += numerator_permutations * denominator_permutations
    return out

N = 2
out = get_number_of_combinations(N)
print(out)

get_partitions由于函数和get_number_of_valid_partitions函数的交互方式,现在代码的扩展性非常差。

编辑

下面的代码要快得多。有一个小的改进divisible_numbers,但主要的加速在于get_number_of_valid_partitions不会创建不必要的临时列表,因为它现在已经加入get_partitions到一个函数中。其他大的加速来自使用numba. 的代码get_number_of_valid_partitions现在几乎无法阅读,因此我添加了一个更简单但速度稍慢的版本,命名为get_number_of_valid_partitions_simple这样您就可以了解复杂函数中发生了什么。

import numba

@numba.njit
def divisible_numbers(n):
    """
    Get all numbers with which n is divisible.
    Modified from·
    """
    # We can save some time by only looking at
    # values up to n/2
    for i in range(4,n//2+1):
        if n % i == 0:
            yield i
    yield n

def get_number_of_combinations(N):
    """
    Get number of ways the equality can be solved under the given constraints
    """
    out = 0
    # Create a dictionary of number of valid partitions
    # for all numerator values we can encounter
    n_valid_partitions = {i: get_number_of_valid_partitions(i, N) for i in range(4,4*N+1)}

    for numerator in range(4,4*N+1):
        numerator_permutations = n_valid_partitions[numerator]
        for denominator in divisible_numbers(numerator):
            if denominator < 4:
                continue
            denominator_permutations = n_valid_partitions[denominator]
            out += numerator_permutations * denominator_permutations
    return out

@numba.njit
def get_number_of_valid_partitions(num, N):
    """
    Get number of valid partitions of num, given that
    num = i + j + 2l, and that 1<=i,j,l<=N.
    """
    count = 0
    # In the following, k = 2*l
    #There's different cases for i,j,k that we can treat separately
    # to give some speedup due to symmetry.
    #i,j can be even or odd. k <= N or N < k <= 2N.

    # Some combinations only possible if num is even/odd
    # num is even
    if num % 2 == 0:
        # i,j odd, k <= 2N
        k_min = max(2, num - 2 * (N - (N + 1) % 2))
        k_max = min(2 * N, num - 2)
        for k in range(k_min, k_max + 1, 2):
            # only look at i<=j
            i_min = max(1, num - k - N + (N + 1) % 2)
            i_max = min(N, (num - k)//2)
            for i in range(i_min, i_max + 1, 2):
                j = num - i - k
                # if i == j, only one permutations
                # otherwise two due to symmetry
                if i == j:
                    count += 1
                else:
                    count += 2

        # i,j even, k <= N
        # only look at k<=i<=j
        k_min = max(2, num - 2 * (N - N % 2))
        k_max = min(N, num // 3)
        for k in range(k_min, k_max + 1, 2):
            i_min = max(k, num - k - N + N % 2)
            i_max = min(N, (num - k) // 2)
            for i in range(i_min, i_max + 1, 2):
                j = num - i - k
                if i == j == k:
                    # if i == j == k, only one permutation
                    count += 1
                elif i == j or i == k or j == k:
                    # if only two of i,j,k are the same there are 3 permutations
                    count += 3
                else:
                    # if all differ, there are six permutations
                    count += 6

        # i,j even, N < k <= 2N
        k_min = max(N + 1 + (N + 1) % 2, num - 2 * N)
        k_max = min(2 * N, num - 4)
        for k in range(k_min, k_max + 1, 2):
            # only look for i<=j
            i_min = max(2, num - k - N + 1 - (N + 1) % 2)
            i_max = min(N, (num - k) // 2)
            for i in range(i_min, i_max + 1, 2):
                j = num - i - k
                if i == j:
                    # if i == j, only one permutation
                    count += 1
                else:
                    # if all differ, there are two permutations
                    count += 2
    # num is odd
    else:
        # one of i,j is even, the other is odd. k <= N
        # We assume that j is odd, k<=i and correct the symmetry in the counts
        k_min = max(2, num - 2 * N + 1)
        k_max = min(N, (num - 1) // 2)
        for k in range(k_min, k_max + 1, 2):
            i_min = max(k, num - k - N + 1 - N % 2)
            i_max = min(N, num - k - 1)
            for i in range(i_min, i_max + 1, 2):
                j = num - i - k
                if i == k:
                    # if i == j, two permutations
                    count += 2
                else:
                    # if i and k differ, there are four permutations
                    count += 4

        # one of i,j is even, the other is odd. N < k <= 2N
        # We assume that j is odd and correct the symmetry in the counts
        k_min = max(N + 1 + (N + 1) % 2, num - 2 * N + 1)
        k_max = min(2 * N, num - 3)
        for k in range(k_min, k_max + 1, 2):
            i_min = max(2, num - k - N + (N + 1) % 2)
            i_max = min(N, num - k - 1)
            for i in range(i_min, i_max + 1, 2):
                j = num - i - k
                count += 2

    return count

@numba.njit
def get_number_of_valid_partitions_simple(num, N):
    """
    Simpler but slower version of 'get_number_of_valid_partitions'
    """
    count = 0

    for k in range(2, 2 * N + 1, 2):
        for i in range(1, N + 1):
            j = num - i - k
            if 1 <= j <= N:
                count += 1
    return count

if __name__ == "__main__":
    N = int(sys.argv[1])
    out = get_number_of_combinations(N)
    print(out)

推荐阅读