首页 > 技术文章 > [luogu P5824] 十二重计数法

zkdxl 2021-04-23 09:44 原文

\(\text{Problem}:\)十二重计数法

\(\text{Solution}:\)

第一重(球之间互不相同,盒子之间互不相同):

对于每个球都有 \(m\) 个盒子放,即 \(m^{n}\)

第二重(球之间互不相同,盒子之间互不相同,每个盒子至多装一个球):

\(n>m\) 时为 \(0\);当 \(n\leq m\) 时为 \(m^{\underline{n}}\)

第三重(球之间互不相同,盒子之间互不相同,每个盒子至少装一个球):

考虑容斥。求出至少有 \(k\) 个盒子为空的方案数,乘上容斥系数即可。答案为:

\[\sum\limits_{k=0}^{m}(-1)^{k}\binom{m}{k}(m-k)^{n} \]

第四重(球之间互不相同,盒子全部相同):

由于盒子都是相同的,只需枚举有多少个盒子非空即可,答案为:

\[\sum\limits_{i=0}^{m}{n\brace i} \]

第五重(球之间互不相同,盒子全部相同,每个盒子至多装一个球):

\([n\leq m]\)

第六重(球之间互不相同,盒子全部相同,每个盒子至少装一个球):

第二类斯特林数,即 \({n\brace m}\)

第七重(球全部相同,盒子之间互不相同):

利用插板法,答案为 \(\binom{n+m-1}{m-1}\)

第八重(球全部相同,盒子之间互不相同,每个盒子至多装一个球):

选择 \(n\) 个盒子放球,即 \(\binom{m}{n}\)

第九重(球全部相同,盒子之间互不相同,每个盒子至少装一个球):

球都是相同的,那么在每个盒子里先放一个球,问题转化为第七重,答案为 \(\binom{n-1}{m-1}\)

第十重(球全部相同,盒子全部相同):

\(f_{n,m}\) 表示把 \(n\) 拆分成 \(m\) 个数(无序)的方案数。首先考虑一个经典的 \(O(n^2)\)\(dp\),有:

\[f_{i,j}=f_{i,j-1}+f_{i-j,j} \]

即每操作新加入一个元素 \(0\),或将当前所有数加 \(1\)。显然这样计数是不重不漏的。

设第 \(i\) 列拆分数的 \(\text{OGF}\)\(F_{i}(x)\),有:

\[F_{i}(x)=F_{i-1}(x)+x^{i}F_{i}(x)=\cfrac{F_{i-1}(x)}{1-x^{i}}\\ \]

\(F_{0}(x)=1\),得到:

\[F_{i}(x)=\prod\limits_{j=1}^{i}\cfrac{1}{1-x^{j}} \]

\(G_{i}(x)=\dfrac{1}{1-x^{i}}\),考虑对两边取 \(\ln\),有:

\[\begin{aligned} \ln G_{i}(x)&=\int(\ln G_{i}(x))'dx\\ &=\int \cfrac{G_{i}(x)'}{G_{i}(x)}dx\\ &=\int ((1-x^{i})\sum\limits_{j=0}^{\infty}ij\cdot x^{ij-1})dx\\ &=\int (\sum\limits_{j=0}^{\infty}ij\cdot x^{ij-1}+\sum\limits_{j=1}^{\infty}i(j-1)\cdot x^{ij-1})dx\\ &=\int (\sum\limits_{j=1}^{\infty}i\cdot x^{ij-1})dx\\ &=\sum\limits_{j=1}^{\infty}\cfrac{x^{ij}}{j} \end{aligned} \]

现在回到 \(F_{i}(x)\),也是对两边取 \(\ln\),有:

\[\ln F_{i}(x)=\sum\limits_{j=1}^{i}\ln G_{j}(x)=\sum\limits_{j=1}^{i}\sum\limits_{k=1}^{\infty}\cfrac{x^{jk}}{k} \]

在模 \(x^{n+1}\) 意义下,\(\ln F_{i}(x)\) 可以在 \(O(n\log n)\) 的时间复杂度内求出,再做一个多项式 \(exp\) 即可解决。答案为 \([x^{n}]F_{m}(x)\)

第十一重(球全部相同,盒子全部相同,每个盒子至多装一个球):

\([n\leq m]\)

第十二重(球全部相同,盒子全部相同,每个盒子至少装一个球):

与第十重做法相同,\(n\rightarrow n-m\) 即可。

\(\text{Code}:\)

#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std; const int N=550010, Mod=998244353; 
inline int read()
{
	int s=0, w=1; ri char ch=getchar();
	while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
	while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
	return s*w;
}
int n,m;
int rev[N],r[24][2],fac[N+5],inv[N+5];
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline int C(int x,int y) { if(x<y||x<0||y<0) return 0; return 1ll*fac[x]*inv[x-y]%Mod*inv[y]%Mod; }
inline void Get_Rev(int T) { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline void DFT(int T,vector<int> &s,int type)
{
	for(ri int i=0;i<T;i++) if(rev[i]<i) swap(s[i],s[rev[i]]);
	for(ri int i=2,cnt=1;i<=T;i<<=1,cnt++)
	{
		int wn=r[cnt][type];
		for(ri int j=0,mid=(i>>1);j<T;j+=i)
		{
			for(ri int k=0,w=1;k<mid;k++,w=1ll*w*wn%Mod)
			{
				int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
				s[j+k]=(x+y)%Mod;
				s[j+mid+k]=x-y;
				if(s[j+mid+k]<0) s[j+mid+k]+=Mod;
			}
		}
	}
	if(!type) for(ri int i=0,inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
inline void NTT(int n,int m,vector<int> &A,vector<int> B)
{
	int len=n+m;
	int T=1;
	while(T<=len) T<<=1;
	Get_Rev(T);
	A.resize(T), B.resize(T);
	for(ri int i=n+1;i<T;i++) A[i]=0;
	for(ri int i=m+1;i<T;i++) B[i]=0;
	DFT(T,A,1), DFT(T,B,1);
	for(ri int i=0;i<T;i++) A[i]=1ll*A[i]*B[i]%Mod;
	DFT(T,A,0);
}
inline void GetInv(int n,vector<int> &F,vector<int> G)
{
	if(n==1) { F[0]=ksc(G[0],Mod-2); return; }
	GetInv((n+1)/2,F,G);
	vector<int> A,B;
	int T=1;
	while(T<=n+n) T<<=1;
	Get_Rev(T);
	A.resize(T), B.resize(T);
	for(ri int i=0;i<n;i++) A[i]=F[i], B[i]=G[i];
	DFT(T,A,1), DFT(T,B,1);
	for(ri int i=0;i<T;i++) A[i]=(2ll*A[i]%Mod-1ll*B[i]*A[i]%Mod*A[i]%Mod+Mod)%Mod;
	DFT(T,A,0);
	for(ri int i=0;i<n;i++) F[i]=A[i]; 
}
inline void GetDao(int n,vector<int> &A,vector<int> B)
{
	for(ri int i=0;i<n-1;i++) A[i]=1ll*(i+1)*B[i+1]%Mod;
	A[n-1]=0;
}
inline void GetJi(int n,vector<int> &A,vector<int> B)
{
	for(ri int i=1;i<n;i++) A[i]=1ll*B[i-1]*fac[i-1]%Mod*inv[i]%Mod;
	A[0]=0;
}
inline void GetLn(int n,vector<int> &F,vector<int> G)
{
	vector<int> A,B;
	A.resize(n), B.resize(n);
	GetDao(n,A,G);
	GetInv(n,B,G);
	NTT(n,n,A,B);
	GetJi(n,F,A);
}
inline void GetExp(int n,vector<int> &F,vector<int> G)
{
	if(n==1) { F[0]=1; return; }
	GetExp((n+1)/2,F,G);
	vector<int> C;
	C.resize(n);
	GetLn(n,C,F);
	vector<int> A,B;
	int T=1;
	while(T<=n+n) T<<=1;
	Get_Rev(T);
	A.resize(T), B.resize(T);
	for(ri int i=0;i<n;i++) A[i]=F[i], B[i]=(G[i]-C[i]+Mod)%Mod; B[0]++, B[0]%=Mod;
	DFT(T,A,1), DFT(T,B,1);
	for(ri int i=0;i<T;i++) A[i]=1ll*A[i]*B[i]%Mod;
	DFT(T,A,0);
	for(ri int i=0;i<n;i++) F[i]=A[i];
}
inline void Task1() { printf("%d\n",ksc(m,n)); }
inline void Task2() { printf("%d\n",(n>m)?(0):(1ll*fac[m]*inv[m-n]%Mod)); }
inline void Task3()
{
	int ans=0;
	for(ri int i=0;i<=m;i++)
	{
		int w=1ll*C(m,i)*ksc(m-i,n)%Mod;
		if(i&1) ans=(ans-w+Mod)%Mod;
		else ans=(ans+w)%Mod;
	}
	printf("%d\n",ans);
}
inline void Task4()
{
	vector<int> A,B;
	A.resize(n+1), B.resize(n+1);
	for(ri int i=0;i<=n;i++)
	{
		A[i]=((i&1)?(-1):(1))*inv[i];
		if(A[i]<0) A[i]+=Mod;
		B[i]=1ll*inv[i]*ksc(i,n)%Mod;
	}
	NTT(n+1,n+1,A,B);
	int ans=0;
	for(ri int i=0;i<=min(n,m);i++) ans=(ans+A[i])%Mod;
	printf("%d\n",ans);
}
inline void Task5() { printf("%d\n",(n<=m)); }
inline void Task6()
{
	int ans=0;
	for(ri int i=0;i<=m;i++)
	{
		int w=1ll*ksc(i,n)*inv[i]%Mod*inv[m-i]%Mod;
		if((m-i)&1) ans=(ans-w+Mod)%Mod;
		else ans=(ans+w)%Mod;
	}
	printf("%d\n",ans);
}
inline void Task7() { printf("%d\n",C(n+m-1,m-1)); }
inline void Task8() { printf("%d\n",C(m,n)); }
inline void Task9() { printf("%d\n",C(n-1,m-1)); }
vector<int> F;
inline void Task10()
{
	F.resize(n+1);
	for(ri int i=1;i<=m;i++)
	{
		for(ri int j=1;i*j<=n;j++)
		{
			F[i*j]=(F[i*j]+1ll*fac[j-1]*inv[j]%Mod)%Mod;
		}
	}
	GetExp(n+1,F,F);
	printf("%d\n",F[n]);
}
inline void Task11() { printf("%d\n",(n<=m)); }
inline void Task12() { printf("%d\n",(n<m)?(0):F[n-m]); }
signed main()
{
	r[23][1]=ksc(3,119), r[23][0]=ksc(ksc(3,Mod-2),119);
	for(ri int i=22;~i;i--) r[i][0]=1ll*r[i+1][0]*r[i+1][0]%Mod, r[i][1]=1ll*r[i+1][1]*r[i+1][1]%Mod;
	fac[0]=1;
	for(ri int i=1;i<=N;i++) fac[i]=1ll*fac[i-1]*i%Mod;
	inv[N]=ksc(fac[N],Mod-2);
	for(ri int i=N;i;i--) inv[i-1]=1ll*inv[i]*i%Mod;
	n=read(), m=read();
	Task1();
	Task2();
	Task3();
	Task4();
	Task5();
	Task6();
	Task7();
	Task8();
	Task9();
	Task10();
	Task11();
	Task12();
	return 0;
}

推荐阅读