首页 > 技术文章 > 排列计数机

Ax-Dea 2021-07-18 16:41 原文

牛客 - 排列计数机

statement

定义一个长为 \(k\) 的序列 \(A_1,A_2,…,A_k\) 的权值为:对于所有 \(1≤i≤k\)\(\max(A_1,A_2,…,A_i)\) 有多少种不同的取值。

给出一个 \(1\)\(n\) 的排列 \(B_1,B_2,…,B_n\),求 \(B\) 的所有非空子序列的权值的 \(m\) 次方之和。

答案对 \(10^9+7\) 取模。

Hints

对于前 \(10\%\) 的数据,\(n≤20\)
对于前 \(20\%\) 的数据,\(n≤100\)
对于前 \(40\%\) 的数据,\(n≤1000\)
对于另外 \(20\%\) 的数据,\(m=1\)
对于所有数据,\(1≤n≤10^5\)\(1≤m≤20\),保证 \(B\)\(1\)\(n\) 的排列。

solution

先想一个暴力 DP,设 \(dp(i, j, k)\) 表示到前 \(i\) 个数,并且第 \(i\) 个数字必须选,最大值为 \(j\),权值为 \(k\) 的方案数,转移:

\[dp(i, j, k) = \sum_{l < i \and p\ge a_i } dp(l, p, k) + \sum_{l < i \and p < a_i} dp(l, p, k - 1) \]

可以做到 \(\mathcal O (n ^ 4) - \mathcal O (n ^ 3)\)

可以发现转移的时候,对于同一个二元组 \((j, k)\) 对应的每一个 \(i\) 总是会在先前的状态原封不动地转移过来,所以可以把原先的第一维度压掉,最后的转移只有:

\[\begin{align} &dp(i, j) \rightarrow dp(A_k, j + 1) & A_k > i\\ &dp(i, j) \rightarrow dp(i, j) & A_k \le i \end{align} \]

其中 \(A_k\) 是当前枚举到的数字。

由于这个 \(m\) 很小,尝试展开这个次幂的式子,首先知道二项式定理:

\[(a + b) ^ m =\sum_{k \ge 0} {m\choose k} a ^ k b ^ {m - k} \]

这启发我们维护每个最大值对应的每个次幂的信息,由于 \(m\) 比较小,看起来非常赚。

\(f(i, j)\) 表示最大值为 \(i\) ,子序列个数的 \(j\) 次幂的大小,看下上面 \(dp\) 的转移。

  • \(dp(i, j) \leftarrow dp(i, j) \ A_k \le i\)

等价于 \(f(i, j) \leftarrow f(i, j) \ A_k \le i\)

  • \(dp(A_k, j) \leftarrow dp(i, j)\ A_k > i\)

等价于 \(f(A_{now}, j) = \sum_i \sum_{k\le j} {j\choose k} f(i, k) \ A_k > i\)

这个 \(f\) 可以用线段树简单维护一下,支持区间乘二,区间求和,区间赋值即可。

时间复杂度 \(\mathcal O (nm^2 + nm\log n)\)

#include <bits/stdc++.h>
#define forn(i,s,t) for(register int i=(s); i<=(t); ++i)
#define form(i,s,t) for(register int i=(s); i>=(t); --i)
#define rep(i,s,t) for(register int i=(s); i<(t); ++i)
using namespace std;
const int N = 1e5 + 3, M = 22, Mod = 1e9 + 7;
struct Mint {
	int res;
	Mint() {}
	Mint(int a) : res(a) {}
	inline friend Mint operator + (Mint A, Mint B) {
		return Mint((A.res + B.res >= Mod) ? (A.res + B.res - Mod) : (A.res + B.res));
	}
	inline friend Mint operator - (Mint A, Mint B) {return A + Mint(Mod - B.res);}
	inline friend Mint operator * (Mint A, Mint B) {return Mint(1ll * A.res * B.res %Mod);}
	inline friend Mint& operator += (Mint& A, Mint B) {return A = A + B;}
	inline friend Mint& operator -= (Mint& A, Mint B) {return A = A - B;}
	inline friend Mint& operator *= (Mint& A, Mint B) {return A = A * B;}
	inline friend Mint q_pow(Mint p, int k = Mod - 2) {
		static Mint res; res = Mint(1);
		for(; k; k >>= 1, p *= p) (k & 1) && (res *= p, 0);
		return res;
	}
	inline friend Mint operator ~ (Mint A) {return q_pow(A);}
};
Mint fac[N], ifac[N];
inline void init(int n) {
	fac[0] = Mint(1);
	forn(i,1,n) fac[i] = fac[i - 1] * Mint(i);
	ifac[n] = ~fac[n];
	form(i,n - 1,0) ifac[i] = ifac[i + 1] * Mint(i + 1);
}
inline Mint C(int n, int r) {return fac[n] * ifac[r] * ifac[n - r];}
Mint pow2[N], F[N][M]; int m;
struct node {
	int tag; Mint cof[M];
};
struct SegTree {
	node val[N << 2];
	inline void up(int p) {forn(i,0,m) val[p].cof[i] = val[p << 1].cof[i] + val[p << 1 | 1].cof[i];}
	inline void opt(int p, int k) {
		val[p].tag += k; forn(i,0,m) val[p].cof[i] *= pow2[k];
	}
	inline void down(int p) {
		opt(p << 1, val[p].tag), opt(p << 1 | 1, val[p].tag), val[p].tag = 0;
	}
	void Upd(int p, int l, int r, int nl, int nr) {
		if(l == nl && nr == r) return opt(p, 1);
		int mid = nl+nr >> 1;
		(val[p].tag) && (down(p), 0);
		if(r <= mid) Upd(p << 1, l, r, nl, mid);
		else if(l > mid) Upd(p << 1 | 1, l, r, mid + 1, nr);
		else Upd(p << 1, l, mid, nl, mid), Upd(p << 1 | 1, mid + 1, r, mid + 1, nr);
		up(p);
	}
	node Qry(int p, int l, int r, int nl, int nr) {
		if(l == nl && nr == r) return val[p];
		int mid = nl+nr >> 1;
		(val[p].tag) && (down(p), 0);
		if(r <= mid) return Qry(p << 1, l, r, nl, mid);
		else if(l > mid) return Qry(p << 1 | 1, l, r, mid + 1, nr);
		else {
			node res;
			node L = Qry(p << 1, l, mid, nl, mid);
			node R = Qry(p << 1 | 1, mid + 1, r, mid + 1, nr);
			forn(i,0,m) res.cof[i] = L.cof[i] + R.cof[i];
			return res;
		}
	}
	void Cov(int p, int l, int r, int pos) {
		if(l == r) {
			forn(i,0,m) val[p].cof[i] = F[pos][i];
			return ;
		}
		int mid = l+r >> 1;
		(val[p].tag) && (down(p), 0);
		if(pos <= mid) Cov(p << 1, l, mid, pos);
		else Cov(p << 1 | 1, mid + 1, r, pos);
		up(p);
	}
}T;
int n, a[N];
int main() {
	scanf("%d%d", &n, &m), init(m);
	pow2[0] = Mint(1);
	forn(i,1,n) scanf("%d", a + i), pow2[i] = pow2[i - 1] + pow2[i - 1];
	forn(i,1,n) {
		static node res;
		T.Upd(1, a[i], n, 1, n), res = T.Qry(1, 1, a[i], 1, n);
		forn(j,0,m) {
			F[a[i]][j] = Mint(1);
			forn(k,0,j) F[a[i]][j] += C(j, k) * res.cof[k];
		}
		T.Cov(1, 1, n, a[i]);
	}
	printf("%d", T.val[1].cof[m]);
	return 0;
}

推荐阅读