首页 > 技术文章 > 「清华集训」小 Y 和恐怖的奴隶主

zsbzsb 2020-05-09 17:02 原文

传送门
观察到 \(m\) 的值很小,考虑把不同血量的随从的个数计入状态。
\(dp_{i, A, B, C}\) 表示在第 \(i\) 次攻击之后,还剩 \(A\) 个一血怪,\(B\) 个二血怪,\(C\) 个三血怪的概率。
转移很显然,只需要注意生成新怪的情况即可。
但是这对于 \(N \le 10^{18}\) 来说是完全不行的。
我们考虑矩阵加速。
首先我们对每一个合法的三元组 \((A, B, C)\) 进行编号(可以发现可能的最大值是 \(166\))。
然后我们把这些编号看做矩阵的行列,构造转移矩阵。
简单的说,我们记录状态之间转移的概率,然后再新增一行、一列来记录期望。
对于一种局面 \((A, B, C)\),它对答案的贡献是 \(\frac{dp_{A, B, C}}{A + B + C + 1}\)
然后我们就可以快乐的矩阵快速幂了…………才怪。
分析一下复杂度,发现每次都跑一次快速幂是跑不过的。
所以我们不妨预处理 \(dp_i\) 表示 \(2 ^ i\) 个转移矩阵的乘积,然后每次只要用一个行向量乘上 \(\log\)\(dp_i\) 就可以了。
参考代码:

#include <cstdio>

const int p = 998244353;

int T, m, k, ans[170], tmp[170]; long long N;
int n = 1, inv[12], id[12][12][12];

void Add(int& a, int b) { a += b, a >= p ? a -= p : 0; }

struct Matrix {
    int a[170][170];
    int* operator [] (int x) { return a[x]; }
    Matrix operator * (Matrix b) const {
        Matrix ans;
        for (int i = 1; i <= n + 1; ++i)
            for (int j = 1; j <= n + 1; ++j) ans[i][j] = 0;
        for (int i = 1; i <= n + 1; ++i)
            for (int j = 1; j <= n + 1; ++j)
                for (int k = 1; k <= n + 1; ++k)
                    Add(ans[i][j], 1ll * a[i][k] * b[k][j] % p);
        return ans;
    }
} dp[170];

void mul(Matrix a) {
    for (int i = 1; i <= n + 1; ++i) tmp[i] = 0;
    for (int i = 1; i <= n + 1; ++i)
        for (int j = 1; j <= n + 1; ++j)
            Add(tmp[i], 1ll * ans[j] * a[j][i] % p);
    for (int i = 1; i <= n + 1; ++i) ans[i] = tmp[i];
}

void power(long long N) {
    for (int i = 0; N; N >>= 1, ++i) if (N & 1) mul(dp[i]);
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("cpp.in", "r", stdin), freopen("cpp.out", "w", stdout);
#endif
    scanf("%d %d %d", &T, &m, &k);
    inv[0] = inv[1] = 1;
    for (int i = 2; i < 10; ++i)
        inv[i] = 1ll * (p - p / i) * inv[p % i] % p;
    for (int A = 0; A <= k; ++A)
        for (int B = 0; B <= (m > 1 ? k - A : 0); ++B)
            for (int C = 0; C <= (m > 2 ? k - A - B : 0); ++C)
                id[A][B][C] = ++n;
    for (int A = 0; A <= k; ++A)
        for (int B = 0; B <= (m > 1 ? k - A : 0); ++B)
            for (int C = 0; C <= (m > 2 ? k - A - B : 0); ++C) {
                int x = id[A][B][C], y = A + B + C < k;
                if (m == 1)
                    if (A) dp[0][x][id[A - 1][B][C]] = 1ll * A * inv[A + B + C + 1] % p;
                if (m == 2) {
                    if (A) dp[0][x][id[A - 1][B][C]] = 1ll * A * inv[A + B + C + 1] % p;
                    if (B) dp[0][x][id[A + 1][B - 1 + y][C]] = 1ll * B * inv[A + B + C + 1] % p;
                }
                if (m == 3) {
                    if (A) dp[0][x][id[A - 1][B][C]] = 1ll * A * inv[A + B + C + 1] % p;
                    if (B) dp[0][x][id[A + 1][B - 1][C + y]] = 1ll * B * inv[A + B + C + 1] % p;
                    if (C) dp[0][x][id[A][B + 1][C - 1 + y]] = 1ll * C * inv[A + B + C + 1] % p;
                }
                dp[0][x][x] = dp[0][x][n + 1] = inv[A + B + C + 1];
            }
    dp[0][n + 1][n + 1] = 1;
    for (int i = 1; i <= 60; ++i) dp[i] = dp[i - 1] * dp[i - 1]; 
    while (T--) {
        scanf("%lld", &N);
        for (int i = 1; i <= n + 1; ++i) ans[i] = 0;
        if (m == 1) ans[id[1][0][0]] = 1;
        if (m == 2) ans[id[0][1][0]] = 1;
        if (m == 3) ans[id[0][0][1]] = 1;
        power(N), printf("%d\n", ans[n + 1]);
    }
    return 0;
}

推荐阅读