首页 > 技术文章 > 「雅礼集训2018」Magic

zsbzsb 2020-06-11 21:37 原文

传送门

看到恰好 \(k\) 对这样的限制,我们考虑容斥:设 \(g_k\) 表示至少有 \(k\) 对的方案,\(ans_k\) 表示恰好有 \(k\) 对的方案:

\[g_k = \sum_{i = k} ^ {n - 1} {i \choose k} ans_i \Rightarrow ans_k = \sum_{i = k} ^ {n - 1} (-1) ^ {i - k} { i \choose k } g_i \]

然后我们考虑如何计算 \(g_k\)

发现同种类型的牌都是一样的,不好算方案,我们不妨先认为同种类型的牌互不相同,那么在最后的答案中我们除以 \(\prod_{i = 1} ^ m a_i !\) 就好了(可重集排列)

我们考虑 \(\text{DP}\) :设 \(dp_{i, j}\) 表示用前 \(i\) 种类型的牌,组成了至少 \(j\) 对的方案,为了方便转移,我们用 \(f_{i, j}\) 表示用第 \(i\) 种类型的牌,组成了至少 \(j\) 对的方案。

考虑 \(f_{i, j}\) 怎么计算,由于我们之前已经钦定了每张牌都不相同,所以我们可以先从 \(a_i\) 中牌中选 \(a_i - j\) 张作为不直接参与答案计算的牌,然后考虑把剩下的 \(j\) 张插入到这些牌的左边,显然这样至少会有 \(j\) 对。那么对于这 \(j\) 张牌,第一张有 \(a_i - j\) 张牌可以插,第二张有 \(a_i - j + 1\) 张可以插,以此类推不难得到 \(f_{i, j}\) 的表达式:

\[f_{i, j} = {a_i \choose a_i - j} (a_i - 1)^{\underline j} \]

转移方程很显然:\(dp_{i, j} = \sum_{k = 0} ^ {j} dp_{i - 1, j - k} \times f_{i, k}\)

不难做到用分治 \(\text{NTT}\)(雾)优化求 \(dp_{i, j}\) 的过程。

但是这里要注意一点就是我们求出来的 \(dp_{m, k}\)\(g_k\) 是不等价的,因为我们对于和魔术对无关的 \(n - k\) 张牌是可以随便排的,也就是说 \(g_k = dp_{m, k} \times (n - k)!\)

最后我们就可以直接计算 \(ans_k\) 了。

参考代码:

#include <algorithm>
#include <cstdio>
#include <vector>
using namespace std;

const int _ = 4e5 + 5, p = 998244353, G = 3, iG = 332748118;

template < class T > void read(T& s) {
    s = 0; int f = 0; char c = getchar();
    while ('0' > c || c > '9') f |= c == '-', c = getchar();
    while ('0' <= c && c <= '9') s = s * 10 + c - 48, c = getchar();
    s = f ? -s : s;
}

int n, m, k, a[_], fac[_], ifc[_], r[_];
vector < int > f[_], g;

int power(int x, int k) {
    int res = 1;
    for (; k; k >>= 1, x = 1ll * x * x % p)
        if (k & 1) res = 1ll * res * x % p;
    return res % p;
}

int C(int N, int M) { return 1ll * fac[N] * ifc[M] % p * ifc[N - M] % p; }

void NTT(vector < int > & A, int N, int type) {
    for (int i = 0; i < N; ++i) if (i < r[i]) swap(A[i], A[r[i]]);
    for (int i = 1; i < N; i <<= 1) {
        int Wn = power(type ? G : iG, (p - 1) / (i << 1));
        for (int j = 0; j < N; j += i << 1)
            for (int k = 0, w = 1; k < i; ++k, w = 1ll * w * Wn % p) {
                int x = A[j + k], y = 1ll * w * A[j + i + k] % p;
                A[j + k] = (x + y) % p, A[j + i + k] = (x - y + p) % p;
            }
    }
    if (!type) {
        int inv = power(N, p - 2);
        for (int i = 0; i < N; ++i) A[i] = 1ll * A[i] * inv % p;
    }
}

vector < int > solve(int L, int R) {
    if (L == R) return f[L];
    int mid = (L + R) >> 1;
    vector < int > A = solve(L, mid);
    vector < int > B = solve(mid + 1, R);
    int N = 1, l = 0;
    while (N <= A.size() + B.size()) N <<= 1, ++l;
    for (int i = 0; i < N; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    while (A.size() != N) A.push_back(0);
    while (B.size() != N) B.push_back(0);
    NTT(A, N, 1), NTT(B, N, 1);
    for (int i = 0; i < N; ++i) A[i] = 1ll * A[i] * B[i] % p;
    NTT(A, N, 0);
    while (A.size() && !A.back()) A.pop_back();
    return A;
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("cpp.in", "r", stdin), freopen("cpp.out", "w", stdout);
#endif
    read(m), read(n), read(k);
    for (int i = 1; i <= m; ++i) read(a[i]);
    fac[0] = 1;
    for (int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % p;
    ifc[n] = power(fac[n], p - 2);
    for (int i = n; i; --i) ifc[i - 1] = 1ll * ifc[i] * i % p;
    for (int i = 1; i <= m; ++i)
        for (int j = 0; j < a[i]; ++j)
            f[i].push_back(1ll * C(a[i], j) * fac[a[i] - 1] % p * ifc[a[i] - j - 1] % p);
    g = solve(1, m);
    for (int i = 0; i < g.size(); ++i) g[i] = 1ll * g[i] * fac[n - i] % p;
    int ans = 0;
    for (int i = k; i < g.size(); ++i) {
        int tmp = 1ll * C(i, k) * g[i] % p;
        if (i - k & 1) ans = (ans - tmp + p) % p;
        else ans = (ans + tmp) % p;
    }
    for (int i = 1; i <= m; ++i) ans = 1ll * ans * ifc[a[i]] % p;
    printf("%d\n", ans);
    return 0;
}

推荐阅读