首页 > 技术文章 > 省选前多项式的挣扎

SperanzaLeaf 2019-03-20 15:34 原文

RT,肯定是学不完了,但是还是要挣扎一下

一个沙茶的挣扎还有什么可看的鸭=。=

废话结束


所有多项式(当然还有生成函数)相关的题都不是凭空冒出来的,一定是有一个问题,然后用它当做工具去优化问题的某个步骤

“基本工具”

FFT 

板子,没啥可说的,要精度记得预处理所有单位根(但是会很慢),不追求精度的时候预处理log个sin和cos先算即可(还有取整的问题)

精度高

 1 //Exact FFT
 2 #include<cmath>
 3 #include<cstdio>
 4 #include<cctype>
 5 #include<cstring>
 6 #include<algorithm>
 7 using namespace std;
 8 const int N=4e6+6,M=30;
 9 const double pai=acos(-1);
10 struct cpx
11 {
12     double x,y;
13 }a[N],b[N],ort[N];
14 cpx operator + (cpx a,cpx b)
15 {
16     return (cpx){a.x+b.x,a.y+b.y};
17 } 
18 cpx operator - (cpx a,cpx b)
19 {
20     return (cpx){a.x-b.x,a.y-b.y};
21 }
22 cpx operator * (cpx a,cpx b)
23 {
24     double x1=a.x,x2=b.x,y1=a.y,y2=b.y;
25     return (cpx){x1*x2-y1*y2,x1*y2+x2*y1};
26 }
27 cpx operator ! (cpx a)
28 {
29     a.y=-a.y; return a;
30 }
31 int n,m,rev[N];
32 double Sin[M],Cos[M];
33 void Read(double &x)
34 {
35     int ret=0; x=0;
36     char ch=getchar();
37     while(!isdigit(ch))
38         ch=getchar();
39     while(isdigit(ch))
40         ret=(ret<<1)+(ret<<3)+(ch^48),ch=getchar();
41     x=ret; return ;
42 }
43 void Write(int x)
44 {
45     if(x>9) Write(x/10);
46     putchar(x%10^48);
47 }
48 int Round(double x)
49 { 
50     if(fabs(x)<0.4) return 0;
51     return x>0?(int)(x+0.5):(int)(x-0.5);
52 }
53 void Prework()
54 {
55     register int i;
56     scanf("%d%d",&n,&m);
57     for(i=0;i<=n;i++) Read(a[i].x);
58     for(i=0;i<=m;i++) Read(b[i].x);
59     m+=n,n=1; while(n<=m) n<<=1;
60     for(i=0;i<n;i++)
61     {
62         rev[i]=(rev[i>>1]>>1)+(i&1)*(n>>1);
63         ort[i]=(cpx){cos(pai*i/n),sin(pai*i/n)};
64     }
65 }
66 void Trans(cpx *cop,int len,int typ)
67 {
68     register int i,j,k;
69     for(i=0;i<len;i++)
70         if(rev[i]>i) swap(cop[i],cop[rev[i]]);
71     for(i=2;i<=len;i<<=1)
72     {
73         int lth=i>>1;
74         for(j=0;j<len;j+=i)
75         {
76             cpx *org=ort;
77             for(k=j;k<j+lth;k++,org+=len/lth)
78             {
79                 cpx tmp=*org; if(typ==-1) tmp=!tmp;
80                 tmp=tmp*cop[k+lth],cop[k+lth]=cop[k]-tmp,cop[k]=cop[k]+tmp;
81             }
82         }
83     }
84     if(typ==-1)
85         for(int i=0;i<=len;i++) cop[i].x/=len;
86 }
87 int main()
88 {
89     register int i;
90     Prework();
91     Trans(a,n,1),Trans(b,n,1);
92     for(i=0;i<n;i++) a[i]=a[i]*b[i];
93     Trans(a,n,-1);
94     for(i=0;i<=m;i++) Write(Round(a[i].x)),putchar(' ');
95     return 0;
96 }
View Code

跑得快

 1 //Really Fast FFT
 2 #include<cmath>
 3 #include<cstdio>
 4 #include<cctype>
 5 #include<cstring>
 6 #include<algorithm>
 7 using namespace std;
 8 const int N=4e6+6,M=30;
 9 const double pai=acos(-1);
10 struct cpx
11 {
12     double x,y;
13 }a[N],b[N];
14 cpx operator + (cpx a,cpx b)
15 {
16     return (cpx){a.x+b.x,a.y+b.y};
17 } 
18 cpx operator - (cpx a,cpx b)
19 {
20     return (cpx){a.x-b.x,a.y-b.y};
21 }
22 cpx operator * (cpx a,cpx b)
23 {
24     double x1=a.x,x2=b.x,y1=a.y,y2=b.y;
25     return (cpx){x1*x2-y1*y2,x1*y2+x2*y1};
26 }
27 int n,m,rev[N];
28 double Sin[M],Cos[M];
29 void Read(double &x)
30 {
31     int ret=0; x=0;
32     char ch=getchar();
33     while(!isdigit(ch))
34         ch=getchar();
35     while(isdigit(ch))
36         ret=(ret<<1)+(ret<<3)+(ch^48),ch=getchar();
37     x=ret; return ;
38 }
39 void Write(int x)
40 {
41     if(x>9) Write(x/10);
42     putchar(x%10^48);
43 }
44 int Round(double x)
45 { 
46     if(fabs(x)<0.4) return 0;
47     return x>0?(int)(x+0.5):(int)(x-0.5);
48 }
49 void Prework()
50 {
51     register int i;
52     scanf("%d%d",&n,&m);
53     for(i=0;i<=n;i++) Read(a[i].x);
54     for(i=0;i<=m;i++) Read(b[i].x);
55     m+=n,n=1; while(n<=m) n<<=1;
56     for(i=0;i<n;i++)
57         rev[i]=(rev[i>>1]>>1)+(i&1)*(n>>1);
58     for(i=1;i<=24;i++)
59         Sin[i]=sin(2*pai/(1<<i)),Cos[i]=cos(2*pai/(1<<i));
60 }
61 void Trans(cpx *cop,int len,int typ)
62 {
63     register int i,j,k;
64     for(i=0;i<len;i++)
65         if(rev[i]>i) swap(cop[i],cop[rev[i]]);
66     for(i=2;i<=len;i<<=1)
67     {
68         int lth=i>>1,lgg=log2(i);
69         cpx omg={Cos[lgg],Sin[lgg]*typ};
70         for(j=0;j<len;j+=i)
71         {
72             cpx ori={1,0},tmp;
73             for(k=j;k<j+lth;k++,ori=ori*omg)
74                 tmp=ori*cop[k+lth],cop[k+lth]=cop[k]-tmp,cop[k]=cop[k]+tmp;
75         }
76     }
77     if(typ==-1)
78         for(int i=0;i<=len;i++) cop[i].x/=len;
79 }
80 int main()
81 {
82     register int i;
83     Prework();
84     Trans(a,n,1),Trans(b,n,1);
85     for(i=0;i<n;i++) a[i]=a[i]*b[i];
86     Trans(a,n,-1);
87     for(i=0;i<=m;i++) Write(Round(a[i].x)),putchar(' ');
88     return 0;
89 }
View Code

NTT

板子,没啥可说的,xehoth大佬那个频率抽取NTT在洛谷上可以总时限跑进1s,但我可能没机会学了=。=

 1 //Simple NTT
 2 #include<cmath>
 3 #include<cstdio>
 4 #include<cctype>
 5 #include<cstring>
 6 #include<algorithm>
 7 using namespace std;
 8 const int N=4000006,mod=998244353;
 9 int n,m,G,Gi,Ni,rev[N],a[N],b[N],pw[30][2];
10 inline void Read(int &x)
11 {
12     x=0;
13     char ch=getchar();
14     while(!isdigit(ch))
15         ch=getchar();
16     while(isdigit(ch))
17         x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
18     return ;
19 }
20 void Write(int x)
21 {
22     if(x>9) Write(x/10);
23     putchar(x%10^48);
24 }
25 int Qpow(int x,int k)
26 {
27     if(k==1) return x;
28     int tmp=Qpow(x,k/2);
29     return k%2?1ll*tmp*tmp%mod*x%mod:1ll*tmp*tmp%mod;
30 }
31 void Prework()
32 {
33     register int i; 
34     Read(n),Read(m);
35     for(i=0;i<=n;i++) Read(a[i]);
36     for(i=0;i<=m;i++) Read(b[i]);
37     m+=n,n=1; while(n<=m) n<<=1;
38     for(i=1;i<n;i++)
39         rev[i]=(rev[i>>1]>>1)+(i&1)*(n>>1);
40     G=3,Gi=Qpow(G,mod-2),Ni=Qpow(n,mod-2);
41     for(int i=1;i<=24;i++)
42     {
43         pw[i][0]=Qpow(G,(mod-1)/(1<<i));
44         pw[i][1]=Qpow(Gi,(mod-1)/(1<<i));
45     }
46 }
47 void Trans(int *arr,int len,int typ)
48 {
49     register int i,j,k;
50     for(i=0;i<len;i++)
51         if(rev[i]>i) swap(arr[rev[i]],arr[i]);
52     for(i=2;i<=len;i<<=1)
53     {
54         int lth=i>>1,ort=pw[(int)log2(i)][typ==-1];
55         for(j=0;j<len;j+=i)
56         {
57             int ori=1,tmp;
58             for(k=j;k<j+lth;k++,ori=1ll*ori*ort%mod)
59             {
60                 tmp=1ll*ori*arr[k+lth]%mod;
61                 arr[k+lth]=(arr[k]-tmp+mod)%mod;
62                 arr[k]=(arr[k]+tmp)%mod;
63             }
64         }
65     }
66     if(typ==-1)
67         for(i=0;i<=len;i++)
68             arr[i]=1ll*arr[i]*Ni%mod;
69 }
70 int main()
71 {
72     register int i;
73     Prework();
74     Trans(a,n,1),Trans(b,n,1);
75     for(i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mod;
76     Trans(a,n,-1);
77     for(i=0;i<=m;i++) Write(a[i]),putchar(' ');
78     return 0;
79 }
洛谷板子

拆系数FFT

各方面吊打任意模数NTT,然而难写的一批

  1 //Coefficients Decomposing FFT
  2 #include<cmath>
  3 #include<cstdio>
  4 #include<cctype>
  5 #include<cstring>
  6 #include<algorithm>
  7 using namespace std;
  8 const int N=4e6+40,M=30;
  9 const int Pow=15,Bas=(1<<Pow)-1;
 10 const double pai=acos(-1);
 11 struct cpx
 12 {
 13     double x,y;
 14     void Turn(int a,int b)
 15     {
 16         x=a,y=b;
 17     }
 18 }a[N],b[N],c[N],d[N];
 19 const cpx b1=(cpx){0.5,0};
 20 const cpx b2=(cpx){0,-0.5};
 21 const cpx b3=(cpx){1,0};
 22 const cpx b4=(cpx){0,1};
 23 cpx operator + (cpx a,cpx b)
 24 {
 25     return (cpx){a.x+b.x,a.y+b.y};
 26 } 
 27 cpx operator - (cpx a,cpx b)
 28 {
 29     return (cpx){a.x-b.x,a.y-b.y};
 30 }
 31 cpx operator * (cpx a,cpx b)
 32 {
 33     double x1=a.x,x2=b.x,y1=a.y,y2=b.y;
 34     return (cpx){x1*x2-y1*y2,x1*y2+x2*y1};
 35 }
 36 cpx operator ! (cpx a)
 37 {
 38     a.y=-a.y; return a;
 39 }
 40 double Sin[M],Cos[M];
 41 int n,m,mod,rev[N],xx[N],yy[N],ans[N];
 42 void Read(int &x)
 43 {
 44     x=0; char ch=getchar();
 45     while(!isdigit(ch))
 46         ch=getchar();
 47     while(isdigit(ch))
 48         x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
 49     x%=mod; return;
 50 }
 51 void Write(int x)
 52 {
 53     if(x>9) Write(x/10);
 54     putchar(x%10|48);
 55 }
 56 int Roumod(double x)
 57 { 
 58     return (long long)(x+0.5)%mod;
 59 }
 60 void Prework()
 61 {
 62     register int i;
 63     scanf("%d%d%d",&n,&m,&mod);
 64     for(i=0;i<=n;i++) Read(xx[i]);
 65     for(i=0;i<=m;i++) Read(yy[i]);
 66     m+=n,n=1; while(n<=m) n<<=1;
 67     for(i=0;i<n;i++)
 68         rev[i]=(rev[i>>1]>>1)+(i&1)*(n>>1);
 69     for(i=1;i<=24;i++)
 70         Sin[i]=sin(2*pai/(1<<i)),Cos[i]=cos(2*pai/(1<<i));
 71 }
 72 void Trans(cpx *cop,int len,int typ)
 73 {
 74     register int i,j,k;
 75     for(i=0;i<len;i++)
 76         if(rev[i]>i) swap(cop[i],cop[rev[i]]);
 77     for(i=2;i<=len;i<<=1)
 78     {
 79         int lth=i>>1,lgg=log2(i);
 80         cpx omg={Cos[lgg],Sin[lgg]*typ};
 81         for(j=0;j<len;j+=i)
 82         {
 83             cpx ori=b3,tmp;
 84             for(k=j;k<j+lth;k++,ori=ori*omg)
 85                 tmp=ori*cop[k+lth],cop[k+lth]=cop[k]-tmp,cop[k]=cop[k]+tmp;
 86         }
 87     }
 88     if(typ==-1)
 89         for(int i=0;i<=len;i++)     
 90             cop[i].x/=len,cop[i].y/=len;
 91 }
 92 void Mul(cpx *c1,cpx *c2,cpx &a1,cpx &a2,int p,int q)
 93 {
 94     cpx t1=(c1[p]+!c1[q])*b1,t2=(c1[p]-!c1[q])*b2;
 95     cpx t3=(c2[p]+!c2[q])*b1,t4=(c2[p]-!c2[q])*b2;
 96     a1=t1*t3+(t1*t4+t2*t3)*b4,a2=t2*t4;
 97 }
 98 void CDFFT(int *p1,int *p2,int len)
 99 {
100     register int i;
101     for(i=0;i<len;i++)
102     {
103         a[i].Turn(p1[i]&Bas,p1[i]>>Pow);
104         b[i].Turn(p2[i]&Bas,p2[i]>>Pow);
105     }
106     Trans(a,len,1),Trans(b,len,1),Mul(a,b,c[0],d[0],0,0);
107     for(i=1;i<n;i++) Mul(a,b,c[i],d[i],i,n-i);
108     Trans(c,len,-1),Trans(d,len,-1);
109     for(i=0;i<len;i++) 
110     {
111         long long x1=Roumod(c[i].x),y1=Roumod(c[i].y),x2=Roumod(d[i].x);
112         ans[i]=(((x2<<(Pow<<1))+(y1<<Pow)+x1)%mod+mod)%mod;
113     }
114 }
115 int main()
116 {
117     register int i;
118     Prework(),CDFFT(xx,yy,n);
119     for(i=0;i<=m;i++) Write(ans[i]),putchar(' ');
120     return 0;
121 }
View Code

“升级过的工具”

多项式求逆

倍增的思想,设原来的多项式是f,我们现在已知的逆元是g,要求下一级逆元g'

那么$g'=2g-fg^2$,边界就是常数项的逆元

所以一个多项式有没有逆元取决于常数项(有没有逆元)

代码和多项式开根放一起

多项式开根

仍然是倍增的思想,设原来的多项式是f,我们现已经开到了g,下一级是g'

那么$g'=\frac{f+g^2}{2g}$,边界是g[0]=1

所以需要求逆做前置科技

注意求逆的数组不要搞混了

(其实这是我写小朋友和二叉树时候顺便改的)

  1 #include<cmath>
  2 #include<cstdio>
  3 #include<vector>
  4 #include<cstring>
  5 #include<algorithm>
  6 #define o 1ll
  7 #define vint vector<int>
  8 #define vit vector<int> ::iterator
  9 using namespace std;
 10 const int N=400005,mod=998244353;
 11 int n,m,rd,G,Gi,inv2;
 12 int a1[N],b1[N],mor[N];
 13 int a2[N],b2[N],tor[N],mos[N];
 14 int rev[N],pw[30][2]; vint g;
 15 int Qpow(int x,int k)
 16 {
 17     if(k==1) return x;
 18     int tmp=Qpow(x,k>>1);
 19     return o*tmp*tmp%mod*((k&1)?x:1)%mod;
 20 }
 21 void Pre()
 22 {
 23     G=3,Gi=Qpow(G,mod-2),inv2=Qpow(2,mod-2);
 24     for(int i=1;i<=24;i++)
 25     {
 26         pw[i][0]=Qpow(G,(mod-1)/(1<<i));
 27         pw[i][1]=Qpow(Gi,(mod-1)/(1<<i));
 28     }
 29 }
 30 int Prework(int len)
 31 {
 32     int ret=1; 
 33     while(ret<=len) ret<<=1;
 34     for(int i=0;i<=ret;i++)
 35         rev[i]=(rev[i>>1]>>1)+(i&1)*(ret>>1);
 36     return ret;
 37 }
 38 
 39 vint Oridec(vint v)
 40 {
 41     int sz=v.size();
 42     for(int i=0,t;i<sz;i++)
 43         t=v[i],v[i]=mod-t;
 44     return v;
 45 }
 46 vint Oriadd(vint v,int x)
 47 {
 48     v[0]+=x; return v;    
 49 }    
 50 vint Orimul(vint v,int x)
 51 {
 52     int sz=v.size();
 53     for(int i=0;i<sz;i++)
 54         v[i]=o*v[i]*x%mod;
 55     return v;
 56 }
 57 
 58 void Trans(int *arr,int len,int typ)
 59 {
 60     for(int i=0;i<=len;i++)
 61         if(rev[i]>i) swap(arr[rev[i]],arr[i]);
 62     for(int i=2;i<=len;i<<=1)
 63     {
 64         int lth=i>>1,ort=pw[(int)log2(i)][typ==-1];
 65         for(int j=0;j<len;j+=i)
 66         {
 67             int tmp,cal=1;
 68             for(int k=j;k<j+lth;cal=o*cal*ort%mod,k++)
 69             {
 70                 tmp=o*arr[k+lth]*cal%mod;
 71                 arr[k+lth]=(arr[k]-tmp+mod)%mod;
 72                 arr[k]=(arr[k]+tmp)%mod;
 73             }
 74         }
 75     }
 76     if(typ==-1)
 77     {
 78         int Ni=Qpow(len,mod-2);
 79         for(int i=0;i<=len;i++)
 80             arr[i]=o*arr[i]*Ni%mod;
 81     }
 82 }
 83 void Getinv(int len,int *ori,int *inv)
 84 {
 85     if(len==1) inv[0]=Qpow(ori[0],mod-2);
 86     else
 87     {
 88         Getinv((len+1)>>1,ori,inv);
 89         int lth=Prework(len<<1);
 90         for(int i=0;i<len;i++) mor[i]=ori[i];
 91         for(int i=len;i<lth;i++) mor[i]=0;
 92         Trans(mor,lth,1),Trans(inv,lth,1);
 93         for(int i=0;i<lth;i++) 
 94             inv[i]=(2+mod-o*inv[i]*mor[i]%mod)*inv[i]%mod;
 95         Trans(inv,lth,-1);
 96         for(int i=len;i<lth;i++) inv[i]=0;
 97     }
 98 }
 99 void Getsqr(int len,int *ori,int *sqr)
100 {
101     if(len==1) sqr[0]=1;
102     else
103     {
104         Getsqr((len+1)>>1,ori,sqr);
105         Getinv(len,sqr,tor);
106         int lth=Prework(len<<1);
107         for(int i=0;i<len;i++) mos[i]=ori[i];
108         for(int i=len;i<lth;i++) mos[i]=0;
109         Trans(mos,lth,1),Trans(sqr,lth,1),Trans(tor,lth,1);
110         for(int i=0;i<lth;i++)
111             sqr[i]=(o*tor[i]*mos[i]%mod+sqr[i])*inv2%mod;
112         Trans(sqr,lth,-1);
113         for(int i=len;i<lth;i++) sqr[i]=0;
114         for(int i=0;i<lth;i++) tor[i]=0;
115     }
116 }
117 vint Polyinv(vint v)
118 {
119     int sz=v.size();
120     for(int i=0;i<sz;i++) a1[i]=v[i];
121     Getinv(sz,a1,b1);
122     for(int i=0;i<sz;i++) v[i]=b1[i];
123     return v;
124 }
125 vint Polysqr(vint v)
126 {
127     int sz=v.size();
128     for(int i=0;i<sz;i++) a2[i]=v[i];
129     Getsqr(sz,a2,b2);
130     for(int i=0;i<sz;i++) v[i]=b2[i];
131     return v;
132 }
133 
134 int main()
135 {
136     scanf("%d",&n),Pre();
137     for(int i=1;i<=n;i++)
138         scanf("%d",&rd),g.push_back(rd);
139     vint sg=Polysqr(g);
140     for(int i=0;i<n;i++) printf("%d ",sg[i]);
141     return 0;
142 }
View Code

分治FFT

经常用在做背包的时候用,不是洛谷那个模板,那个严格来说是“CDQ-FFT”,这个是“DC-FFT”,我以前的代码可能经常搞混2333

比如这道题

  1 #include<cmath>
  2 #include<cstdio>
  3 #include<cctype>
  4 #include<cstring>
  5 #include<algorithm>
  6 using namespace std;
  7 const int N=800005,M=30,mod=99991;
  8 const int Pow=15,Bas=(1<<Pow)-1;
  9 const double pai=acos(-1);
 10 struct cpx
 11 {
 12     double x,y;
 13     void Turn(int a,int b)
 14     {
 15         x=a,y=b;
 16     }
 17 }a[N],b[N],c[N],d[N],ort[N];
 18 const cpx b1=(cpx){0.5,0};
 19 const cpx b2=(cpx){0,-0.5};
 20 const cpx b3=(cpx){1,0};
 21 const cpx b4=(cpx){0,1};
 22 cpx operator + (cpx a,cpx b)
 23 {
 24     return (cpx){a.x+b.x,a.y+b.y};
 25 } 
 26 cpx operator - (cpx a,cpx b)
 27 {
 28     return (cpx){a.x-b.x,a.y-b.y};
 29 }
 30 cpx operator * (cpx a,cpx b)
 31 {
 32     double x1=a.x,x2=b.x,y1=a.y,y2=b.y;
 33     return (cpx){x1*x2-y1*y2,x1*y2+x2*y1};
 34 }
 35 cpx operator ! (cpx a)
 36 {
 37     a.y=-a.y; return a;
 38 }
 39 char BF[1<<23],*P1=BF,*P2=BF;
 40 char Gc(){return (P1==P2&&(P2=(P1=BF)+fread(BF,1,1<<21,stdin),P1==P2)?EOF:*P1++);}
 41 template<class Type> void Fread(Type &x)
 42 {
 43     x=0; char ch=Gc();
 44     while(!isdigit(ch)) ch=Gc();
 45     while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=Gc();
 46 }
 47 int Roumod(double x)
 48 { 
 49     return (long long)(x+0.5)%mod;
 50 }
 51 int Qpow(int x,int k)
 52 {
 53     if(k==1) return x;
 54     int tmp=Qpow(x,k/2);
 55     return k%2?1ll*tmp*tmp%mod*x%mod:1ll*tmp*tmp%mod;
 56 }
 57 #define vint vector<int>
 58 #define vit vector<int> ::iterator 
 59 double Sin[M],Cos[M];
 60 int n,m,f0,f1,k1,k2,anss;
 61 int num[N],odf[N],rev[N],xx[N],yy[N],ans[N]; vint fuc;
 62 
 63 void Trans(cpx *cop,int len,int typ)
 64 {
 65     register int i,j,k;
 66     for(i=0;i<len;i++)
 67         if(rev[i]>i) swap(cop[i],cop[rev[i]]);
 68     for(i=2;i<=len;i<<=1)
 69     {
 70         int lth=i>>1;
 71         for(j=0;j<len;j+=i)
 72         {
 73             cpx *pts=ort;
 74             for(k=j;k<j+lth;pts+=len/lth,k++)
 75             {
 76                 cpx tmp=*pts; if(typ==-1) tmp=!tmp;
 77                 tmp=tmp*cop[k+lth],cop[k+lth]=cop[k]-tmp,cop[k]=cop[k]+tmp;
 78             }
 79         }
 80     }
 81     if(typ==-1)
 82         for(int i=0;i<=len;i++)     
 83             cop[i].x/=len,cop[i].y/=len;
 84 }
 85 void Mul(cpx *c1,cpx *c2,cpx &a1,cpx &a2,int p,int q)
 86 {
 87     cpx t1=(c1[p]+!c1[q])*b1,t2=(c1[p]-!c1[q])*b2;
 88     cpx t3=(c2[p]+!c2[q])*b1,t4=(c2[p]-!c2[q])*b2;
 89     a1=t1*t3+(t1*t4+t2*t3)*b4,a2=t2*t4;
 90 }
 91 void CDFFT(int *p1,int *p2,int *ans,int len)
 92 {
 93     register int i;
 94     for(i=0;i<len;i++)
 95     {
 96         a[i].Turn(p1[i]&Bas,p1[i]>>Pow),p1[i]=0;
 97         b[i].Turn(p2[i]&Bas,p2[i]>>Pow),p2[i]=0; 
 98     }
 99     Trans(a,len,1),Trans(b,len,1),Mul(a,b,c[0],d[0],0,0);
100     for(i=1;i<len;i++) Mul(a,b,c[i],d[i],i,len-i);
101     Trans(c,len,-1),Trans(d,len,-1);
102     for(i=0;i<len;i++) 
103     {
104         long long x1=Roumod(c[i].x),y1=Roumod(c[i].y),x2=Roumod(d[i].x);
105         ans[i]=(((x2<<(Pow<<1))+(y1<<Pow)+x1)%mod+mod)%mod;
106     }
107 }
108 vint Merge(vint v1,vint v2)
109 {
110     register int i;
111     vint ret; ret.clear();
112     int l1=v1.size()-1,l2=v2.size()-1,len=l1+l2;
113     for(i=0;i<=l1;i++) xx[i]=v1[i];
114     for(i=0;i<=l2;i++) yy[i]=v2[i];
115     int lth=1; while(lth<=len) lth<<=1; 
116     for(i=0;i<=lth;i++)
117     {
118         rev[i]=(rev[i>>1]>>1)+(i&1)*(lth>>1);
119         ort[i]=(cpx){cos(pai*i/lth),sin(pai*i/lth)};
120     }
121     CDFFT(xx,yy,ans,lth);
122     for(i=0;i<=len;i++) ret.push_back(ans[i]);
123     return ret;
124 }
125 vint Divide(int l,int r)
126 {
127     if(l==r)     
128     {
129         vint ret; ret.clear();
130         ret.push_back(1);
131         ret.push_back(odf[l]);
132         return ret;
133     }
134     else
135     {
136         int mid=(l+r)>>1;
137         vint ls=Divide(l,mid);
138         vint rs=Divide(mid+1,r);
139         return Merge(ls,rs);
140     } 
141 }
142 int main()
143 {
144     register int i;
145     Fread(n),Fread(m);
146     for(i=1;i<=n;i++) Fread(num[i]);
147     Fread(f0),Fread(f1);
148     k2=1ll*(f0+f1)*Qpow(4,mod-2)%mod,k1=(f0-k2+mod)%mod;
149     for(i=1;i<=n;i++) odf[i]=num[i]%2?(mod-1):1;
150     fuc=Divide(1,n),anss=1ll*fuc[m]*k1%mod;
151     for(i=1;i<=n;i++) odf[i]=Qpow(3,num[i]);
152     fuc=Divide(1,n),anss=(anss+1ll*fuc[m]*k2%mod)%mod;
153     printf("%d",anss);
154     return 0;
155 }
View Code

CDQ-FFT

后面的项依赖于前面的项的卷积,用CDQ的思想来做就好

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<algorithm>
 4 using namespace std;
 5 const int N=100005,mod=998244353;
 6 int a[4*N],b[4*N],rev[4*N],f[N],g[N],n,G,Gi;
 7 void exGCD(int a,int b,int &x,int &y)
 8 {
 9     if(!b) {x=1,y=0; return ;}
10     exGCD(b,a%b,y,x),y-=a/b*x;
11 }
12 int Qpow(int x,int k)
13 {
14     if(k==1) return x;
15     int tmp=Qpow(x,k/2);
16     return k%2?1ll*tmp*tmp%mod*x%mod:1ll*tmp*tmp%mod;
17 }
18 int Inv(int x,int m)
19 {
20     int xx,yy;
21     exGCD(x,m,xx,yy);
22     return (xx%m+m)%m;
23 }
24 void NTT(int *arr,int len,int typ)
25 {
26     for(int i=0;i<=len;i++)
27         if(rev[i]>i) swap(arr[rev[i]],arr[i]);
28     for(int i=2;i<=len;i<<=1)
29     {
30         int lth=i>>1,ort=Qpow(~typ?G:Gi,(mod-1)/i);
31         for(int j=0;j<len;j+=i)
32         {
33             int ori=1,tmp;
34             for(int k=j;k<j+lth;k++,ori=1ll*ori*ort%mod)
35             {
36                 tmp=1ll*ori*arr[k+lth]%mod;
37                 arr[k+lth]=(arr[k]-tmp+mod)%mod;
38                 arr[k]=(arr[k]+tmp)%mod;
39             }
40         }
41     }
42     if(typ==-1)
43         for(int i=0,ni=Inv(len,mod);i<len;i++) 
44             arr[i]=1ll*arr[i]*ni%mod;
45 }
46 void CDQ(int l,int r,int mid)
47 {
48     int len=r-l+1,m=1;
49     for(int i=l;i<=mid;i++) a[i-l]=f[i];
50     for(int i=0;i<len;i++) b[i]=g[i]; len+=mid-l+1;
51     while(m<=len) m<<=1;
52     for(int i=1;i<=m;i++) rev[i]=(rev[i>>1]>>1)+(i&1)*(m>>1);
53     NTT(a,m,1),NTT(b,m,1);
54     for(int i=0;i<=m;i++) a[i]=1ll*a[i]*b[i]%mod;
55     NTT(a,m,-1);
56     for(int i=mid+1;i<=r;i++) f[i]+=a[i-l],f[i]%=mod;
57     for(int i=0;i<=m;i++) a[i]=b[i]=0;
58 }
59 void Divide(int l,int r)
60 {
61     if(l==r) return;
62     int mid=(l+r)/2;
63     Divide(l,mid),CDQ(l,r,mid),Divide(mid+1,r);
64 }
65 int main()
66 {
67     scanf("%d",&n);
68     for(int i=1;i<n;i++) scanf("%d",&g[i]);
69     f[0]=1,G=3,Gi=Inv(G,mod),Divide(0,n-1);
70     for(int i=0;i<n;i++) printf("%d ",f[i]);
71     return 0;
72 }
洛谷板子,很久以前写的,很naive

生成函数

看原来的笔记吧

(上面那个分治FFT的例题其实用了生成函数

这种东西还是要用用才会用

如果纯看字面的式子可能有些反常识:为什么会自己转移到自己啊=。=???

例题 小朋友和二叉树

考虑一个朴素的枚举左右子树拼起来的DP,边界是一个点的时候方案是1

那么设$f$为树的权值的生成函数,再设一个生成函数g表示集合中是否有某一个数,则有$f=f^2g+1$

解得$f=\frac{1±\sqrt {1-4g}}{2g}$

答案是$f=\frac{1-\sqrt {1-4g}}{2g}$

你怎么知道我是答案?

我比较笨,所以采用亿泰的方法:两个都试一试

(或者推一推看看哪一个符合边界条件,然而我根本不会推)

然后这个东西就可以再化一化减少一点计算量:上下同时乘一个$1+\sqrt {1-4g}$

最后答案就是$\frac{2}{1+\sqrt {1-4g}}$

  1 #include<cmath>
  2 #include<cstdio>
  3 #include<vector>
  4 #include<cstring>
  5 #include<algorithm>
  6 #define o 1ll
  7 #define vint vector<int>
  8 #define vit vector<int> ::iterator
  9 using namespace std;
 10 const int N=400005,mod=998244353;
 11 int n,m,rd,G,Gi,inv2;
 12 int a1[N],b1[N],mor[N];
 13 int a2[N],b2[N],tor[N],mos[N];
 14 int rev[N],pw[30][2]; vint g;
 15 int Qpow(int x,int k)
 16 {
 17     if(k==1) return x;
 18     int tmp=Qpow(x,k>>1);
 19     return o*tmp*tmp%mod*((k&1)?x:1)%mod;
 20 }
 21 void Pre()
 22 {
 23     G=3,Gi=Qpow(G,mod-2),inv2=Qpow(2,mod-2);
 24     for(int i=1;i<=24;i++)
 25     {
 26         pw[i][0]=Qpow(G,(mod-1)/(1<<i));
 27         pw[i][1]=Qpow(Gi,(mod-1)/(1<<i));
 28     }
 29 }
 30 int Prework(int len)
 31 {
 32     int ret=1; 
 33     while(ret<=len) ret<<=1;
 34     for(int i=0;i<=ret;i++)
 35         rev[i]=(rev[i>>1]>>1)+(i&1)*(ret>>1);
 36     return ret;
 37 }
 38 
 39 vint Oriadd(vint v,int x)
 40 {
 41     v[0]+=x; return v;    
 42 }    
 43 vint Orimul(vint v,int x)
 44 {
 45     int sz=v.size();
 46     for(int i=0;i<sz;i++)
 47         v[i]=o*v[i]*x%mod;
 48     return v;
 49 }
 50 
 51 void Trans(int *arr,int len,int typ)
 52 {
 53     for(int i=0;i<=len;i++)
 54         if(rev[i]>i) swap(arr[rev[i]],arr[i]);
 55     for(int i=2;i<=len;i<<=1)
 56     {
 57         int lth=i>>1,ort=pw[(int)log2(i)][typ==-1];
 58         for(int j=0;j<len;j+=i)
 59         {
 60             int tmp,cal=1;
 61             for(int k=j;k<j+lth;cal=o*cal*ort%mod,k++)
 62             {
 63                 tmp=o*arr[k+lth]*cal%mod;
 64                 arr[k+lth]=(arr[k]-tmp+mod)%mod;
 65                 arr[k]=(arr[k]+tmp)%mod;
 66             }
 67         }
 68     }
 69     if(typ==-1)
 70     {
 71         int Ni=Qpow(len,mod-2);
 72         for(int i=0;i<=len;i++)
 73             arr[i]=o*arr[i]*Ni%mod;
 74     }
 75 }
 76 void Getinv(int len,int *ori,int *inv)
 77 {
 78     if(len==1) inv[0]=Qpow(ori[0],mod-2);
 79     else
 80     {
 81         Getinv((len+1)>>1,ori,inv);
 82         int lth=Prework(len<<1);
 83         for(int i=0;i<len;i++) mor[i]=ori[i];
 84         for(int i=len;i<lth;i++) mor[i]=0;
 85         Trans(mor,lth,1),Trans(inv,lth,1);
 86         for(int i=0;i<lth;i++) 
 87             inv[i]=(2+mod-o*inv[i]*mor[i]%mod)*inv[i]%mod;
 88         Trans(inv,lth,-1);
 89         for(int i=len;i<lth;i++) inv[i]=0;
 90     }
 91 }
 92 void Getsqr(int len,int *ori,int *sqr)
 93 {
 94     if(len==1) sqr[0]=1;
 95     else
 96     {
 97         Getsqr((len+1)>>1,ori,sqr);
 98         Getinv(len,sqr,tor);
 99         int lth=Prework(len<<1);
100         for(int i=0;i<len;i++) mos[i]=ori[i];
101         for(int i=len;i<lth;i++) mos[i]=0;
102         Trans(mos,lth,1),Trans(sqr,lth,1),Trans(tor,lth,1);
103         for(int i=0;i<lth;i++)
104             sqr[i]=(o*tor[i]*mos[i]%mod+sqr[i])*inv2%mod;
105         Trans(sqr,lth,-1);
106         for(int i=len;i<lth;i++) sqr[i]=0;
107         for(int i=0;i<lth;i++) tor[i]=0;
108     }
109 }
110 vint Polyinv(vint v)
111 {
112     int sz=v.size();
113     for(int i=0;i<sz;i++) a1[i]=v[i];
114     Getinv(sz,a1,b1);
115     for(int i=0;i<sz;i++) v[i]=b1[i];
116     return v;
117 }
118 vint Polysqr(vint v)
119 {
120     int sz=v.size();
121     for(int i=0;i<sz;i++) a2[i]=v[i];
122     Getsqr(sz,a2,b2);
123     for(int i=0;i<sz;i++) v[i]=b2[i];
124     return v;
125 }
126 
127 int main()
128 {
129     scanf("%d%d",&n,&m);
130     g.resize(m+1),Pre();
131     for(int i=1;i<=n;i++)
132     {
133         scanf("%d",&rd);
134         if(rd<=m) g[rd]=mod-4;
135     }
136     vint sg=Polysqr(Oriadd(g,1));
137     vint rg=Polyinv(Oriadd(sg,1)),g=Orimul(rg,2);
138     for(int i=1;i<=m;i++)
139         printf("%d\n",g[i]);
140     return 0;
141 }
View Code

再说一个例子吧

洛谷 4233 射命丸文的笔记

(不过这题可就不像上面那个那么裸了

显然要分开算,先算n个点带标号竞赛图里哈密顿回路的总数,然后除以n个点带标号强连通竞赛图的数目

总数怎么算?先钦定一个环来排列,这样环里每条边都被算了n次,剩下的边随便连,所以总数就是$(n-1)!2^{{C_n^2}-n}$

图的数目稍微麻烦一些

容斥来算,钦定在缩点后形成的DAG里拓扑序最小的SCC,然后剩下的那些点之间随便连,这个SCC和剩下的点之间的边方向固定。所以得到一个递推式

$f[i]=2^{C_i^2}-\sum\limits_{j=1}^{i-1}f[j]*C_i^j* 2^{C_{i-j}^2}$

出题人:随便推推就出来了

右边移到左边去

$\sum\limits_{j=1}^{i}f[j]*C_i^j* 2^{C_{i-j}^2}=2^{C_i^2}$

拆组合数并移项

$\sum\limits_{j=1}^i\frac{f[j]}{j!}*\frac{2^{C_{i-j}^2}}{(i-j)!}=\frac{2^{C_i^2}}{i!}$

这个时候已经可以CDQ-FFT了,不过我们可以更进一步推成求逆的式子再丢一个log

设生成函数$F(x)=\frac{f[x]}{x!},G(x)=\frac{2^{C_x^2}}{x!}$,那么有

$G=FG+1(+1$是边界啦)

所以$F=\frac{G-1}{G}$

  1 #include<cmath>
  2 #include<cstdio>
  3 #include<vector>
  4 #include<cstring>
  5 #include<algorithm>
  6 #define o 1ll
  7 #define vint vector<int>
  8 using namespace std;
  9 const int N=400005,mod=998244353;
 10 int n,G,Gi,rd,fac[N],inv[N],ans[N];
 11 int a[N],b[N],mem[N],rev[N],pw[30][2]; vint f,g;
 12 void Clean(int *arr,int siz){memset(arr,0,siz);}
 13 int Qpow(int x,int k)
 14 {
 15     if(k<=1) return k?x:1;
 16     int tmp=Qpow(x,k>>1);
 17     return o*tmp*tmp%mod*((k&1)?x:1)%mod;
 18 }
 19 
 20 int Prework(int len)
 21 {
 22     int lth=1;
 23     while(lth<=len) lth<<=1;
 24     for(int i=0;i<=lth;i++)
 25         rev[i]=(rev[i>>1]>>1)+(i&1)*(lth>>1);
 26     return lth;
 27 }
 28 void Trans(int *arr,int len,int typ)
 29 {
 30     register int i,j,k;
 31     for(i=0;i<len;i++)
 32         if(rev[i]>i) swap(arr[rev[i]],arr[i]);
 33     for(i=2;i<=len;i<<=1)
 34     {
 35         int lth=i>>1,ort=pw[(int)log2(i)][typ==-1];
 36         for(j=0;j<len;j+=i)
 37         {
 38             int ori=1,tmp;
 39             for(k=j;k<j+lth;k++,ori=o*ori*ort%mod)
 40             {
 41                 tmp=o*ori*arr[k+lth]%mod;
 42                 arr[k+lth]=(arr[k]-tmp+mod)%mod;
 43                 arr[k]=(arr[k]+tmp)%mod;
 44             }
 45         }
 46     }
 47     if(typ==-1)
 48     {
 49         int Ni=Qpow(len,mod-2);
 50         for(i=0;i<=len;i++)
 51             arr[i]=o*arr[i]*Ni%mod;
 52     }
 53 }
 54 void Getinv(int len,int *ori,int *inv)
 55 {
 56     if(len==1) inv[0]=Qpow(ori[0],mod-1);
 57     else
 58     {
 59         Getinv((len+1)>>1,ori,inv);
 60         int lth=Prework(len<<1);
 61         for(int i=0;i<len;i++) mem[i]=ori[i];
 62         for(int i=len;i<=lth;i++) mem[i]=0;
 63         Trans(mem,lth,1),Trans(inv,lth,1);
 64         for(int i=0;i<=lth;i++)
 65             inv[i]=(2-o*mem[i]*inv[i]%mod+mod)*inv[i]%mod;
 66         Trans(inv,lth,-1);
 67         for(int i=len;i<=lth;i++) inv[i]=0;
 68     }
 69 }
 70 vint Oridec(vint v,int x)
 71 {
 72     (v[0]+=mod-x)%=mod;
 73     return v;
 74 }
 75 vint Polyinv(vint v)
 76 {
 77     int len=v.size();
 78     for(int i=0;i<len;i++) a[i]=v[i];
 79     Getinv(len,a,b);
 80     for(int i=0;i<len;i++) v[i]=b[i];
 81     Clean(a,len*4),Clean(b,len*4);
 82     return v;
 83 }
 84 vint Polymul(vint x,vint y)
 85 {
 86     vint ret; ret.clear();
 87     int l1=x.size()-1,l2=y.size()-1,len=l1+l2;
 88     for(int i=0;i<=l1;i++) a[i]=x[i];
 89     for(int i=0;i<=l2;i++) b[i]=y[i];
 90     int lth=Prework(len);
 91     Trans(a,lth,1),Trans(b,lth,1);
 92     for(int i=0;i<lth;i++) a[i]=o*a[i]*b[i]%mod;
 93     Trans(a,lth,-1);
 94     for(int i=0;i<=len;i++) ret.push_back(a[i]);
 95     Clean(a,lth*4),Clean(b,lth*4);
 96     return ret;
 97 }
 98 
 99 void Pre()
100 {
101     fac[0]=inv[0]=1;
102     for(int i=1;i<=n;i++) 
103         fac[i]=o*fac[i-1]*i%mod;
104     inv[n]=Qpow(fac[n],mod-2);
105     for(int i=n-1;i;i--)
106         inv[i]=o*inv[i+1]*(i+1)%mod;
107     G=3,Gi=Qpow(3,mod-2);
108     for(int i=1;i<=24;i++)
109     {
110         pw[i][0]=Qpow(G,(mod-1)/(1<<i));
111         pw[i][1]=Qpow(Gi,(mod-1)/(1<<i));
112     }
113 }
114 void Calc()
115 {
116     g.resize(n+1);
117     for(int i=0;i<=n;i++) 
118         g[i]=o*Qpow(2,o*i*(i-1)/2%(mod-1))*inv[i]%mod;
119     vint fz=Oridec(g,1),fm=Polyinv(g),fuc=Polymul(fz,fm); 
120     ans[1]=1,ans[2]=-1;
121     for(int i=3;i<=n;i++)
122     {
123         int fz=o*fac[i-1]*Qpow(2,o*i*(i-3)/2%(mod-1))%mod; 
124         int fm=o*fuc[i]*fac[i]%mod;
125         ans[i]=o*fz*Qpow(fm,mod-2)%mod; if(!ans[i]) ans[i]=-1;
126     }
127 }
128 int main()
129 {
130     scanf("%d",&n),Pre(),Calc();
131     for(int i=1;i<=n;i++)
132         printf("%d\n",ans[i]);
133     return 0;
134 }
View Code

 

推荐阅读