首页 > 技术文章 > XJOI 3866 写什么名字好呢

BlogOfchc1234567890 2018-10-27 20:33 原文

题意

给你一个数组\(R\),包含\(N\)个元素,求有多少满足条件的序列A使得

\[0≤A[i]≤R[i] \]

\[A[0]+A[1]+...+A[N−1]=A[0]or A[1]...or A[N−1] \]

输出答案对\(1e9+9\)取模

输入格式

第一行输入一个整数\(N (2≤N≤10)\)
第二行输入\(N\) 个整数 \(R[i] (1≤R[i]≤1e18)\)

输出格式

输出一个整数

样例输入&输出

样例1

2
3 5

15

样例2

3
3 3 3

16

样例3

2
1 128

194

样例4

4
26 74 25 30
 
8409

样例5

2
1000000000 1000000000
 
420352509

分析

这道题目一定有很大价值。我想了一天
首先,暴力dfs肯定不行,看到\(0≤A[i]≤R[i]\),联想到最近学的数位dp。

如果不会数位dp或不太熟练,建议先学一学或巩固一下,否则接下来可能会不能理解

由$$A[0]+A[1]+...+A[N−1]=A[0]or A[1]...or A[N−1]$$这个条件,我们容易推出它的充分条件是

\(A_i\) 的二进制第 \(j\) 位为\(A_{i,j}\) ,则对于某一位 \(j\) ,有\(\sum_{i=1}^{n} A_{i,j}=1\)\(0\) ,即在第 \(j\) 位上,只存在一个 \(A_i\) 或不存在任何一个 \(A_i\) ,使得 \(A_{i,j}\)\(1\)

由此我们dp,设 \(dp[j][i][k](i≠0)\) 表示在第 \(j\) 位上,\(A_{i,j}\) 都等于 \(1\),其余的\(A_{i,j'}(j' \in [1,n],j' ≠0)\) 都为 \(0\)\(k\) 为状态压缩,即用二进制压掉; \(dp[j][0][k]\) 表示在第 \(j\) 位上,所有的 \(A_{i,j}\) 都为 \(0\)

上面为什么要状压?因为在数位dp中,对于当前dfs到的数要用一个limit记录它是否受最高位限制,如果对于每一个 \(A_i\) 都开一个limit,那代码会变得十分冗杂。用二进制不仅能简化代码,还能优化掉空间复杂度\(\frac{1}{8}\)的常数。由于 \(n\le 10\) ,所以初始化limit变量为1023即可。别忘了把dp数组初始化为 \(-1\)

Code

#include<cstdio>
#define maxn 12
#define maxw 70
#define mod 1000000009 //别写错
#define get(x,pos) bool((x)&(1ll<<pos)) //返回x的二进制位的第pos+1个数字
#define set0(x,pos) ((x)&(1023-(1<<pos))) //把x的第pos+1个二进制位设为0
using namespace std;
int lg(long long x){ //log2(懒得调用cmath)
	int ret=-1;
	while(x){
		ret++;
		x>>=1;
	}
	return ret;
}
long long a[maxn],dp[maxw][maxn][1030],ma; //记得开long long
int n;
long long dfs(int pos,int limit){ //数位dp的记搜写法(循环写法我不太会)
	if(pos<0){
		return 1;
	}
	long long ret=0,ans=0;
	int up,tmp=limit;
	for(int i=1;i<=n;i++){
		if(dp[pos][i][limit]!=-1){
			ret=(ret+dp[pos][i][limit])%mod;
		}
		else{
			up=get(limit,i-1)?get(a[i],pos):1;
			ans=0;
			if(up){
				for(int j=1;j<=n;j++){
					if(j!=i){
						if(!(get(limit,j-1)&&!get(a[j],pos)))limit=set0(limit,j-1);
					}
				}
				ans=dfs(pos-1,limit);
				limit=tmp;
			}
			dp[pos][i][limit]=ans;
			ret=(ret+ans)%mod;
		}
	}
	if(dp[pos][0][limit]!=-1){
		ret=(ret+dp[pos][0][limit])%mod;
	}
	else{
		for(int j=1;j<=n;j++){
			if(!(get(limit,j-1)&&!get(a[j],pos)))limit=set0(limit,j-1);
		}
		ans=dfs(pos-1,limit);
		limit=tmp;
		dp[pos][0][limit]=ans;
		ret=(ret+ans)%mod;
	}
	return ret;
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++){
		scanf("%lld",&a[i]);
		if(a[i]>ma)ma=a[i];
	}
	ma=lg(ma);
	for(int i=0;i<=ma;i++){
		for(int j=0;j<=n;j++){
			for(int k=0;k<1024;k++){
				dp[i][j][k]=-1;
			}
		}
	}
	printf("%lld\n",dfs(ma,1023));
	return 0;
}

推荐阅读