首页 > 技术文章 > LuoguP6694 强迫症 找规律+NTT

guangheli 2020-07-29 15:56 原文

比赛的时候切的,可惜没抢到首 A.

权值和等于 $\sum_{i=1}^{n} \sum_{j=i+1}^{n} a_{i} \times a_{j} \times f(j-i+1) \times f(n-j+i+2)$,最后再除以 4.  

然后这里的 $f(n)$ 就代表 $n$ 个点构成的圆环的不相交生成子图的方案数.

可以将前 $5$ 项打表打出来,上 $OEIS$ 找到通项公式,然后就可以拿到 $40$pts 的好成绩.         

观察上面的式子,我们发现可以枚举 $f(j-i+1)$ 的长度,然后就要求有贡献的点对距离必须为 $(j-i+1)$,这个用 NTT 来求即可. 

code: 

#include <cstdio>  
#include <cstring>
#include <algorithm>
#define N 100007
#define ll long long 
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std; 
int a[N],pw[N],n,val[N],f[N],A[N<<2],B[N<<2];   
int qpow(int x,int y) { 
    int tmp=1; 
    for(;y;y>>=1,x=(ll)x*x%mod) {   
        if(y&1) {   
            tmp=(ll)tmp*x%mod;  
        }
    }
    return tmp;  
}  
int get_inv(int x) { 
    return qpow(x,mod-2);  
}
void NTT(int *a,int len,int op) { 
    for(int i=0,k=0;i<len;++i) {   
        if(i>k) { 
            swap(a[i],a[k]);
        }  
        for(int j=len>>1;(k^=j)<j;j>>=1);  
    }     
    for(int l=1;l<len;l<<=1) { 
        int wn=qpow(3,(mod-1)/(l<<1));  
        if(op==-1) { 
            wn=get_inv(wn); 
        }  
        for(int i=0;i<len;i+=l<<1) { 
            int w=1;  
            for(int j=0;j<l;++j) { 
                int x=a[i+j],y=(ll)w*a[i+j+l]%mod;  
                a[i+j]=(ll)(x+y)%mod;  
                a[i+j+l]=(ll)(x-y+mod)%mod;  
                w=(ll)w*wn%mod;  
            }
        }
    }
    if(op==-1) { 
        int iv=get_inv(len);  
        for(int i=0;i<len;++i) {    
            a[i]=(ll)a[i]*iv%mod;  
        }
    }
}     
void init() {
    pw[0]=1;
    for(int i=1;i<N;++i) {
        pw[i]=(ll)pw[i-1]*2%mod;
    }
    a[0]=a[1]=1;
    for(int i=2;i<N;++i) {
        a[i]=(ll)(6*i-3)*a[i-1]%mod;
        (a[i]+=(ll)(mod-(ll)(i-2)*a[i-2]%mod)%mod)%=mod;
        a[i]=(ll)a[i]*get_inv(i+1)%mod;
    }
    f[2]=2;
    for(int i=3;i<N;++i) {
        f[i]=(ll)a[i-2]*pw[i]%mod;
    }
}   
int main() {
    // setIO("input");
    init();
    scanf("%d",&n);
    for(int i=0;i<n;++i) {
        scanf("%d",&val[i]);   
    }   
    for(int i=0;i<n;++i) { 
        A[i]=val[i];  
        B[i]=val[n-1-i];  
    }
    int lim; 
    for(lim=1;lim<=2*n;lim<<=1);   
    NTT(A,lim,1),NTT(B,lim,1);  
    for(int i=0;i<lim;++i) { 
        A[i]=(ll)A[i]*B[i]%mod; 
    }
    NTT(A,lim,-1);      
    int ans=0,inv2=get_inv(4); 
    for(int i=1;i<n;++i) { 
        int x=i+1;  
        int y=n-x+2;   
        (ans+=(ll)A[i+n-1]*f[x]%mod*f[y]%mod*inv2%mod)%=mod; 
    } 
    printf("%d\n",(ll)ans*get_inv(f[n])%mod);
    return 0;
}

  

推荐阅读