首页 > 技术文章 > FFT简单入门板子

zpj61 2021-01-27 10:58 原文

DFT（迭代基 2 FFT，基于 std::complex）:

 // Complex-FFT solver: headers, shared types and global work buffers.
#include <cstdio>
#include <iostream>
#include <cmath>
#include <complex>
typedef double db;
typedef long long ll;

using namespace std;
// Complex type used throughout. A typedef (instead of the original macro)
// is scoped and type-checked.
typedef complex<db> com;
const int N=1e5+10;          // capacity parameter; every buffer below holds 6*N entries
const db pi=acos(-1);
// main() pads the transform length M to a power of two >= 3n, which can reach
// 2^19 = 524288 when n is close to N, hence the 6*N sizing.
int n,h[N*6],M;              // n: input count; h: bit-reversal permutation; M: transform length
com q[N*6],r[N*6];           // q: input sequence, later the convolution result; r: kernel
com a[N*6];                  // scratch buffer used by dft()
15 void dft(com *src,int sig) {
16     for (int i=0; i<M; i++) a[h[i]]=src[i]; //蝶形变换的准备
17     for (int m=2; m<=M; m<<=1) { //正在求的组大小
18         int half=m>>1;
19         for (int i=0; i<half; i++) { //求A[i]与A[i+k],先枚举这个方便处理主根
20             com w=com(cos(i*2*pi/m) , sig * sin(i*2*pi/m));
21             //必须一步一求,不然精度会出锅。
22             //由于转移到m,因此按照式子是m次根。
23             for (int j=i; j<M; j+=m) { //第几组
24                 int k=j+half;
25                 com u=a[j], v=a[k]*w;
26                 a[j]=u+v,a[k]=u-v;
27             }
28         }
29     }
30     for (int i=0; i<M; i++) src[i]=a[i];
31 }
32 int main() {
33     //freopen("3617.in","r",stdin);
34     cin>>n;
35     for (int i=0; i<n; i++) scanf("%lf",&q[i].real());
36     for (M=1; M<3*n; M<<=1);
37     for (int i=0; i<2*n-1; i++) {
38         if (i==n-1) r[i]=0; else 
39         r[i]=pow(i-(n-1),-2) * (i<n-1?-1:1);
40     }
41     for (int i=0; i<M; i++) h[i]=(h[i>>1]>>1) + ((i&1) * (M>>1));
42     //以次数界-1为长度,翻转二进制
43     
44     dft(q,1); dft(r,1);
45     for (int i=0; i<M; i++) q[i]=q[i]*r[i];
46     dft(q,-1);
47     for (int i=n-1; i<n+n-1; i++) printf("%lf\n",q[i].real() / M);
48     //不要忘记除次数界!!!
49 }

 

NTT（快速数论变换，模数 998244353）:

#include <cstdio>
#include <iostream>
using namespace std;
// NTT modulus 998244353 = 119 * 2^23 + 1, with primitive root g = 3.
typedef long long ll;
const int N = 400010;        // capacity of every buffer below (4e5 + 10)
const int mo = 998244353;
const int g = 3;

int n, m, M;                 // input degrees and the transform length
int A[N], B[N], h[N];        // coefficient buffers and bit-reversal permutation
ll w[N], iw[N];              // powers of the forward / inverse root of unity used by ntt()

// Modular exponentiation: returns x^y mod mo, always normalized into [0, mo).
// Iterative square-and-multiply. x is first reduced modulo mo, so any ll value
// (negative or >= mo) is accepted, and every product stays below mo^2 < 2^63.
// The original returned x unreduced when y == 1.
// y is expected to be >= 0. A negative y yields 1 (the original recursed forever).
ll ksm(ll x,ll y) {
    x %= mo;
    if (x < 0) x += mo;
    ll res = 1;
    for (; y > 0; y >>= 1) {
        if (y & 1) res = res * x % mo;
        x = x * x % mo;
    }
    return res;
}
void ntt(int *a,int sz,int sig) {
    for (int i = 1; i < sz; i++) 
        h[i] = (h[i>>1]>>1) + (i & 1) * (sz >> 1);
    for (int i = 0; i <sz; i++) 
        if (h[i]<i) swap(a[i],a[h[i]]);

    for (int m = 1; m < sz; m<<=1) {
        int td = M / (m<<1);
        for (int i = 0; i < sz; i += (m<<1)) {
            for (int j = 0; j < m; j++) {
                ll T = a[i+j+m] * (sig == 1 ? w[td * j] : iw[td * j]) % mo;
                a[i+j+m] = (a[i+j] - T) % mo;
                a[i+j]   = (a[i+j] + T) % mo;
            }
        }
    }
}

int main() {
    freopen("test.in","r",stdin);
    cin>>n>>m;;
    for (int i=0; i<=n; i++) scanf("%d",&A[i]);
    for (int i=0; i<=m; i++) scanf("%d",&B[i]);
    for (M=1; M<=n+m; M<<=1);
    for (int i=1; i<M; i++) 
        h[i]=(h[i>>1]>>1) + (i&1) * (M>>1);
    ll ww = ksm(3, (mo - 1) / M);
    iw[0] = w[0] = 1;
    for (int i = 1; i < M; i++) w[i] = w[i-1] * ww % mo;
    ww = ksm(ww, mo - 2);
    for (int i = 1; i < M; i++) iw[i] = iw[i-1] * ww % mo;

    ntt(A,M,1);
    ntt(B,M,1);
    for (int i=0; i<M; i++) A[i]=(ll)A[i]*B[i]%mo;
    ntt(A,M,-1);
    ll cs=ksm(M,mo-2);
    for (int i=0; i<=n+m; i++) printf("%lld ",(A[i]*cs%mo+mo)%mo);
}

FFT（手写复数结构体版）:

#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
// Buffer capacity (must be >= the padded transform length N) and numeric constants.
const int maxn = 270010;
const double pi = acos(-1.0);
const double eps = 1e-4;

// Minimal complex number: a is the real part, b the imaginary part.
struct Complex {
    double a, b;
    Complex(double re = 0.0, double im = 0.0) : a(re), b(im) {}
    Complex operator+(const Complex &rhs) const { return Complex(a + rhs.a, b + rhs.b); }
    Complex operator-(const Complex &rhs) const { return Complex(a - rhs.a, b - rhs.b); }
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    Complex operator*(const Complex &rhs) const {
        const double re = a * rhs.a - b * rhs.b;
        const double im = a * rhs.b + b * rhs.a;
        return Complex(re, im);
    }
};
Complex A[maxn], B[maxn];     // coefficient buffers for the two input polynomials

void FFT(Complex*,int,int);
int n, m;                     // input degrees (main turns them into coefficient counts)
int N = 1;                    // padded transform length, grown to a power of two in main
int main(){
    scanf("%d%d",&n,&m);
    n++;m++;
    while(N<n+m)N<<=1;
    for(int i=0;i<n;i++)scanf("%lf",&A[i].a);
    for(int i=0;i<m;i++)scanf("%lf",&B[i].a);
    FFT(A,N,1);
    FFT(B,N,1);
    for(int i=0;i<N;i++)A[i]=A[i]*B[i];
    FFT(A,N,-1);
    for(int i=0;i<n+m-1;i++)printf("%d ",(int)(A[i].a+eps));
    return 0;
}
// In-place iterative radix-2 FFT of A[0..n-1]; n must be a power of two.
//   tp =  1 : forward transform with twiddles e^{-2*pi*i/k}
//   tp = -1 : inverse transform, including the 1/n normalization of both the
//             real and the imaginary parts
// Twiddles within a stage are generated by repeated multiplication by wn,
// so rounding error grows with the stage length.
void FFT(Complex *A,int n,int tp){
    // Bit-reversal permutation: j tracks the bit-reversed value of i. It must be
    // derived from the parameter n; the original read the global N, which is
    // wrong for any call where n != N.
    for(int i=1,j=0;i<n-1;i++){
        int k=n;
        do{
            k>>=1;
            j^=k;
        }while(j<k);
        if(i<j)swap(A[i],A[j]);
    }
    for(int k=2;k<=n;k<<=1){
        Complex wn(cos(-tp*2*pi/k),sin(-tp*2*pi/k));   // principal k-th root for this stage
        for(int i=0;i<n;i+=k){
            Complex w(1.0,0.0);
            for(int j=0;j<(k>>1);j++,w=w*wn){
                Complex a(A[i+j]),b(w*A[i+j+(k>>1)]);
                A[i+j]=a+b;
                A[i+j+(k>>1)]=a-b;
            }
        }
    }
    // Normalize both components so the inverse is a true inverse, not just its real part.
    if(tp<0)for(int i=0;i<n;i++){A[i].a/=n;A[i].b/=n;}
}

 

推荐阅读