首页 > 技术文章 > 51nod1538:一道难题(常系数线性递推/Cayley-Hamilton定理)

cjoieryl 2018-12-20 22:18 原文

传送门

Sol

考虑所求式子的组合意义，问题转化为：
有 \(n\) 种小球，每种的大小为 \(a_i\)，每种可以重复选，求选出大小总和为 \(m\) 的小球排成一排的排列数。
有递推 \(f_i=\sum_{j=1}^{n}f_{i-a_j}\)，其中 \(f_0=1\)，下标为负时 \(f\) 取 \(0\)。

常系数线性递推

求一个满足 \(k\) 阶齐次线性递推的数列 \(f_i\) 的第 \(n\) 项：

\[f_n=\sum\limits_{i=1}^{k}a_i \times f_{n-i} \]

给出 \(a_1,\dots,a_k\) 以及初值 \(f_0,\dots,f_{k-1}\)，
其中 \(k\) 为 \(10^5\) 级别，\(n\le 10^{18}\)。
它的特征多项式为

\[C(x)=x^k-\sum_{i=1}^{k}a_ix^{k-i} \]

如果 \(n\) 不是很大，可以直接用生成函数：对 \(C(x)\) 系数翻转得到的 \(1-\sum_{i=1}^{k}a_ix^i\) 求逆（一般情况下还要乘上由初值决定的、次数小于 \(k\) 的分子多项式），即可得到 \(f_1,\dots,f_n\)。
否则：
设向量 \(\alpha_i=(f_i,f_{i+1},\dots,f_{i+k-1})\)，
设转移矩阵为 \(M\)，满足 \(\alpha_iM=\alpha_{i+1}\)，
那么 \(\alpha_0M^n=\alpha_n\)。
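引入 Cayley-Hamilton 定理：
设 \(P(x)\) 为 \(M\) 的特征多项式，把 \(M\) 当作 \(x\) 代入 \(P(x)\)，有 \(P(M)=0\)（全 \(0\) 矩阵）。
于是设 \(x^n=Q(x)P(x)+g(x)\)，其中 \(g(x)=x^n \bmod P(x)\)，\(\deg g<k\)。代入 \(M\) 得 \(M^n=Q(M)P(M)+g(M)=g(M)\)，所以 \(\alpha_n=\alpha_0M^n=\alpha_0\,g(M)\)。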

如何求 \(P(x)\)？
结论是 \(P(x)=C(x)\)，验证如下。
把 \(M\) 写出来：

\[M=\begin{pmatrix} 0 & 0 & 0 & \cdots & 0 & 0 & a_{k} \\ 1 & 0 & 0 & \cdots & 0 & 0 & a_{k-1} \\ 0 & 1 & 0 & \cdots & 0 & 0 & a_{k-2} \\ \vdots & \vdots & \vdots & \ddots & \vdots & \vdots & \vdots \\ 0 & 0 & 0 & \cdots & 0 & 0 & a_{3} \\ 0 & 0 & 0 & \cdots & 1 & 0 & a_{2} \\ 0 & 0 & 0 & \cdots & 0 & 1 & a_{1} \end{pmatrix} \]

根据定义 \(P(x)=|xI-M|\)，其中 \(I\) 为单位矩阵，
那么

\[xI-M=\begin{pmatrix} x & 0 & 0 & \cdots & 0 & 0 & -a_{k} \\ -1 & x & 0 & \cdots & 0 & 0 & -a_{k-1} \\ 0 & -1 & x & \cdots & 0 & 0 & -a_{k-2} \\ \vdots & \vdots & \vdots & \ddots & \vdots & \vdots & \vdots \\ 0 & 0 & 0 & \cdots & x & 0 & -a_{3} \\ 0 & 0 & 0 & \cdots & -1 & x & -a_{2} \\ 0 & 0 & 0 & \cdots & 0 & -1 & x-a_{1} \end{pmatrix} \]

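按最后一列展开：第 \(i\) 行、第 \(k\) 列元素的代数余子式恰为 \(x^{i-1}\)。这是因为删去该行该列后，左上块是对角线为 \(x\) 的下三角，右下块是对角线为 \(-1\) 的上三角，其符号 \((-1)^{k-i}\) 与 \((-1)^{i+k}\) 正好抵消。于是得到
\(P(x)=x^k-a_1x^{k-1}-a_2x^{k-2}-\cdots-a_k=C(x)\)
所以只要用多项式快速幂（倍增）在模 \(C(x)\) 意义下求出 \(g(x)=x^n \bmod C(x)\) 就好了，复杂度 \(\Theta(k\log k\log n)\)（听起来就慢）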
最后，记 \(g(x)=\sum_{i=0}^{k-1}g_ix^i\)，有

\[\alpha_n=\alpha_0\sum_{i=0}^{k-1}g_iM^i=\sum_{i=0}^{k-1}g_i\alpha_i \]

取等式两边向量的第一个分量，得到

\[f_n=\sum_{i=0}^{k-1}g_if_i \]

\(\Theta(k)\) 计算即可
注意快速幂过程中取模的多项式始终是 \(C(x)\)，所以多项式取模时用到的（翻转后的）逆只需预处理一次

解决

由生成方式可知小球大小都在 \([1,23333]\) 内，所以 \(k=23333\)，套常系数线性递推的方法即可；前 \(23333\) 项（\(f_0,\dots,f_{k-1}\)）可以对 \(1-\sum a_ix^i\) 求逆预处理得到

# include <bits/stdc++.h>
using namespace std;
using ll = long long;

// NTT-friendly modulus: mod - 1 = 25 * 2^22, so power-of-two transform lengths up to 2^22
// exist. It is assumed prime: modular inverses are taken as x^(mod-2) throughout.
constexpr int mod = 104857601;
// Capacity of the FFT scratch buffers, i.e. the largest supported transform length.
constexpr int maxn = 1 << 16;
static_assert((mod - 1) % maxn == 0, "transform length must divide mod - 1");

// Modular exponentiation x^y mod mod by repeated squaring.
// The base is reduced into [0, mod) first, so any ll (including negative values) works.
// A negative exponent is mapped to an equivalent non-negative one using
// x^(mod-1) == 1 for x != 0, so Pow(x, -1) is the modular inverse of x
// (the old loop `for (; y; y >>= 1)` never terminated for y < 0).
inline int Pow(ll x, int y) {
	x %= mod;
	if (x < 0) x += mod;
	if (y < 0) y = y % (mod - 1) + (mod - 1);
	ll ret = 1;
	for (; y > 0; y >>= 1, x = x * x % mod)
		if (y & 1) ret = ret * x % mod;
	return static_cast<int>(ret);
}

// x = (x + y) mod mod, assuming both operands already lie in [0, mod).
inline void Inc(int &x, const int y) {
	if ((x += y) >= mod) x -= mod;
}

namespace FFT {
	// Shared scratch state for one transform at a time (not reentrant).
	// Operands are loaded into a and b, and the result is left in a.
	// len = transform length (a power of two), l = log2(len), r = bit-reversal permutation,
	// w[0] / w[1] = powers of the forward / inverse len-th root of unity.
	int a[maxn], b[maxn], len, r[maxn], l, w[2][maxn];

	// Sets len to the smallest power of two >= n, zeroes a[0..len) and b[0..len),
	// and rebuilds the bit-reversal table and the root tables for that length.
	inline void Init(const int n) {
		for (l = 0, len = 1; len < n; len <<= 1) ++l;
		assert(len <= maxn);  // the scratch buffers hold at most maxn entries
		// reverse(i) = reverse(i / 2) / 2, plus the top bit when i is odd. The top bit is written
		// as len >> 1 rather than 1 << (l - 1), which was a negative shift (UB) when len == 1.
		r[0] = 0;
		for (int i = 1; i < len; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (len >> 1) : 0);
		for (int i = 0; i < len; ++i) a[i] = b[i] = 0;
		// Assumes 3 is a primitive root of mod, so 3^((mod-1)/len) has order len.
		const int x = Pow(3, (mod - 1) / len), y = Pow(x, mod - 2);
		w[1][0] = w[0][0] = 1;
		for (int i = 1; i < len; ++i) {
			w[0][i] = (ll)w[0][i - 1] * x % mod;
			w[1][i] = (ll)w[1][i - 1] * y % mod;
		}
	}

	// In-place iterative NTT of p[0..len).
	// opt == 1 is the forward transform; opt == -1 is the inverse transform, including the 1/len scaling.
	inline void NTT(int *p, const int opt) {
		for (int i = 0; i < len; ++i) if (r[i] < i) swap(p[r[i]], p[i]);
		const int *root = w[opt == -1];
		for (int half = 1; half < len; half <<= 1) {
			const int span = half << 1;
			const int step = len / span;  // stride into the root table for this butterfly size
			for (int j = 0; j < len; j += span)
				for (int k = 0; k < half; ++k) {
					const int x = p[j + k];
					const int y = (ll)root[step * k] * p[half + j + k] % mod;
					p[j + k] = x + y >= mod ? x + y - mod : x + y;
					p[half + j + k] = x - y < 0 ? x - y + mod : x - y;
				}
		}
		if (opt == -1) {
			const int inv_len = Pow(len, mod - 2);
			for (int i = 0; i < len; ++i) p[i] = (ll)p[i] * inv_len % mod;
		}
	}

	// Cyclic convolution a * b of length len; the result is left in a (b is overwritten by its transform).
	inline void Calc1() {
		NTT(a, 1), NTT(b, 1);
		for (int i = 0; i < len; ++i) a[i] = (ll)a[i] * b[i] % mod;
		NTT(a, -1);
	}

	// Cyclic product a * b * b of length len; the result is left in a (b is overwritten by its transform).
	inline void Calc2() {
		NTT(a, 1), NTT(b, 1);
		for (int i = 0; i < len; ++i) a[i] = (ll)a[i] * b[i] % mod * b[i] % mod;
		NTT(a, -1);
	}
}

struct Poly {
	vector <int> v;

	inline Poly() {
		v.resize(1);
	}

	inline Poly(const int d) {
		v.resize(d);
	}

	inline int Length() const {
		return v.size();
	}

	inline Poly operator +(Poly b) const {
		register int i, l1 = Length(), l2 = b.Length(), l3 = max(l1, l2);
		register Poly c(l3);
		for (i = 0; i < l1; ++i) c.v[i] = v[i];
		for (i = 0; i < l2; ++i) Inc(c.v[i], b.v[i]);
		return c;
	}

	inline Poly operator -(Poly b) const {
		register int i, l1 = Length(), l2 = b.Length(), l3 = max(l1, l2);
		register Poly c(l3);
		for (i = 0; i < l1; ++i) c.v[i] = v[i];
		for (i = 0; i < l2; ++i) Inc(c.v[i], mod - b.v[i]);
		return c;
	}

	inline void InvMul(Poly b) {
		register int i, l1 = Length(), l2 = b.Length(), l3 = l1 + l2 - 1;
		FFT :: Init(l3);
		for (i = 0; i < l1; ++i) FFT :: a[i] = v[i];
		for (i = 0; i < l2; ++i) FFT :: b[i] = b.v[i];
		FFT :: Calc2();
	}

	inline Poly operator *(Poly b) const {
		register int i, l1 = Length(), l2 = b.Length(), l3 = l1 + l2 - 1;
		register Poly c(l3);
		FFT :: Init(l3);
		for (i = 0; i < l1; ++i) FFT :: a[i] = v[i];
		for (i = 0; i < l2; ++i) FFT :: b[i] = b.v[i];
		FFT :: Calc1();
		for (i = 0; i < l3; ++i) c.v[i] = FFT :: a[i];
		return c;
	}

	inline Poly operator *(int b) const {
		register int i, l = Length();
		register Poly c(l);
		for (i = 0; i < l; ++i) c.v[i] = (ll)v[i] * b % mod;
		return c;
	}

	inline int Calc(const int x) {
		register int i, ret = v[0], l = Length(), now = x;
		for (i = 1; i < l; ++i) Inc(ret, (ll)now * v[i] % mod), now = (ll)now * x % mod;
		return ret;
	}
};

inline void Inv(Poly p, Poly &q, int len) {
	if (len == 1) {
		q.v[0] = Pow(p.v[0], mod - 2);
		return;
	}
	Inv(p, q, len >> 1);
	register int i;
	p.InvMul(q);
	for (i = 0; i < len; ++i) q.v[i] = ((ll)2 * q.v[i] + mod - FFT :: a[i]) % mod;
}

// Returns the power-series inverse of a, truncated to a.Length() terms,
// so that a * Inverse(a) == 1 (mod x^{a.Length()}). Requires a.v[0] != 0.
inline Poly Inverse(const Poly &a) {
	const int n = a.Length();
	if (n == 0) return Poly(0);  // nothing to invert; previously read a.v[0] out of range
	int len = 1;
	while (len < n) len <<= 1;   // Inv doubles precision per level, so it needs a power of two
	Poly c(len);                 // zero-filled: Inv expects the not-yet-computed tail to be 0
	Inv(a, c, len);
	c.v.resize(n);
	return c;
}

// Power-series inverse of the reversed modulus (set up in main() from the reverse of c).
// It is kept global because every reduction in this program uses the same modulus.
Poly invc;

// Remainder of a modulo b, via the reversal identity
//   rev(quotient) = rev(a) * rev(b)^{-1}  (mod x^{n-m+1}),
// where n = a.Length() and m = b.Length().
// rev(b)^{-1} is read from the global invc, not computed from b. So b must be the polynomial
// invc was built from, and invc must hold at least n - m + 1 terms.
// Returns a unchanged when it is shorter than b. Otherwise returns m - 1 coefficients
// (at least one, so the result is never an empty Poly).
inline Poly operator %(const Poly &a, const Poly &b) {
	const int n = a.Length(), m = b.Length();
	if (n < m) return a;
	const int res = n - m + 1;  // number of quotient coefficients
	// A shorter invc would be silently zero-padded below and give a wrong quotient.
	assert(res <= invc.Length());
	Poly ra = a, rinv = invc;
	std::reverse(ra.v.begin(), ra.v.end());
	ra.v.resize(res);  // the top res coefficients of a, reversed
	rinv.v.resize(res);
	Poly q = ra * rinv;
	q.v.resize(res);
	std::reverse(q.v.begin(), q.v.end());  // q is now the quotient a / b
	Poly r = a - q * b;
	r.v.resize(std::max(m - 1, 1));
	return r;
}

int n, k = 23333, f[maxn], a[maxn], ans;
ll m;
Poly c, trs, tmp;

int main() {
	register int i, sa, sb, v;
	scanf("%d%lld%d%d%d", &n, &m, &v, &sa, &sb), sa %= k, sb %= k;
	for (++a[v], i = 2; i <= n; ++i) ++a[v = (v * sa + sb) % k + 1];
	c.v.resize(k + 1), c.v[k] = 1;
	for (i = 1; i <= k; ++i) c.v[k - i] = (mod - a[i]) % mod;
	tmp.v.resize(2), trs.v.resize(1), trs.v[0] = tmp.v[1] = 1;
	invc = c, reverse(invc.v.begin(), invc.v.end()), invc = Inverse(invc);
	for (i = 0; i < k; ++i) f[i] = invc.v[i];
	if (m < k) return printf("%d\n", f[m]), 0;
	for (; m; m >>= 1, tmp = tmp * tmp % c) if (m & 1) trs = trs * tmp % c;
	for (i = 0; i < k; ++i) Inc(ans, (ll)trs.v[i] * f[i] % mod);
	printf("%d\n", ans);
    return 0;
}

推荐阅读