首页 > 技术文章 > 数论变换(NTT)

zjjws 2020-12-23 17:46 原文

前置知识(?)

FFT


算法流程

和 FFT 的思想是一样的,它们的关系就像是高斯消元中的 有模数和无模数


注意事项

选取模数

需满足\(p=2^l\times k+1\),其中 \(l,k\in \mathbb{N_+}\)

对于非取模题

可以根据最终答案范围来决定,若是小于模数,直接做即可。若是那种 保证在 long long 范围内 的题,可以跑两到三遍 NTT,然后 Crt 合并。

对于取模题

需要给定的模数满足上述条件。

否则,有三种解决方法:

  • 跑多遍 NTT,然后合并。

  • 魔改一下 FFT,但是要及其注意精度问题。

  • 把出题人阿掉。


【模板】A*B Problem升级版

#include <stdio.h>
#define LL long long
using namespace std;
const int N=3e6+3;
const int M=998244353;
const int K=3;
inline int prpr(int x,int y){return 1LL*x*y%M;}
inline int ksm(int x,int y){int ans=1;for(;y;y>>=1){if(y&1)ans=prpr(ans,x);x=prpr(x,x);}return ans;}
inline void jh(int &x,int &y){x^=y^=x^=y;return;}
inline int gc(){int c=getchar();return c;}

const int Kr=ksm(K,M-2);

int n,m;
int a[N];
int b[N];

inline void Read()
{
    n=-1;m=-1;
    int x=gc();
    for(;x>='0'&&x<='9';x=gc())a[++n]=x-'0';
    for(;x<'0'||x>'9';x=gc());
    for(;x>='0'&&x<='9';x=gc())b[++m]=x-'0';
    
    for(int i=0;(i<<1)<n;i++)jh(a[i],a[n-i]);
    for(int i=0;(i<<1)<m;i++)jh(b[i],b[m-i]);
    return;
}

int num[N];
int lens;
inline void init()
{
    int lg=0;
    int ls=n+m+1;
    for(lens=1;lens<ls;lens<<=1)lg++;
    for(int i=1;i<lens;i++)num[i]=((num[i>>1]>>1)|((i&1)<<lg-1));
    return;
}

inline void NTT(int *a,bool bj)
{
    for(int i=1;i<lens;i++)if(i<num[i])jh(a[i],a[num[i]]);
    for(int i=1;i<lens;i<<=1)
    {
        int gyq=ksm((bj?K:Kr),(M-1)/(i<<1));
        for(int j=0;j<lens;j+=(i<<1))
        {
            int zjj=1;
            for(int k=0;k<i;k++,zjj=prpr(zjj,gyq))
            {
                int x=a[j+k],y=prpr(a[j+k+i],zjj);
                a[j+k]=(x+y)%M;
                a[j+k+i]=(x-y+M)%M;
            }
        }
    }
    if(!bj)
    {
        int gyq=ksm(lens,M-2);
        for(int i=0;i<lens;i++)a[i]=prpr(a[i],gyq);
    }
    return;
}

int main()
{
    int i,j;
    Read();
    // for(int i=n;i>=0;i--)printf("%d",a[i]);printf("\n");
    // for(int i=m;i>=0;i--)printf("%d",b[i]);printf("\n");
    init();
    NTT(a,1);
    NTT(b,1);
    for(int i=0;i<lens;i++)a[i]=prpr(a[i],b[i]);
    NTT(a,0);
    for(int i=0;i<lens;i++)a[i+1]+=a[i]/10,a[i]%=10;
    for(;a[lens];lens++)a[lens+1]+=a[lens]/10,a[lens]%=10;
    for(;lens&&!a[lens-1];lens--);
    for(int i=lens-1;i>=0;i--)printf("%d",a[i]);
    if(!lens)printf("0");
    printf("\n");
    return 0;
}

最后附个 质数&原根 表(从忘记了什么地方复制过来的)。

//(g 是mod(r*2^k+1)的原根)
素数  r  k  g
3   1   1   2
5   1   2   2
17  1   4   3
97  3   5   5
193 3   6   5
257 1   8   3
7681    15  9   17
12289   3   12  11
40961   5   13  3
65537   1   16  3
786433  3   18  10
5767169 11  19  3
7340033 7   20  3
23068673    11  21  3
104857601   25  22  3
167772161   5   25  3
469762049   7   26  3
1004535809  479 21  3
2013265921  15  27  31
2281701377  17  27  3
3221225473  3   30  5
75161927681 35  31  3
77309411329 9   33  7
206158430209    3   36  22
2061584302081   15  37  7
2748779069441   5   39  3
6597069766657   3   41  5
39582418599937  9   42  5
79164837199873  9   43  5
263882790666241 15  44  7
1231453023109121    35  45  3
1337006139375617    19  46  3
3799912185593857    27  47  5
4222124650659841    15  48  19
7881299347898369    7   50  6
31525197391593473   7   52  3
180143985094819841  5   55  6
1945555039024054273 27  56  5
4179340454199820289 29  57  3

推荐阅读