首页 > 技术文章 > 隔壁小孩教会了我NTT

Paranoid5 2021-10-27 19:14 原文

\(NTT\)

\(FFT\)使用复数单位根对\(DFT\)进行优化,\(NTT\)则使用了另外一种方式优化。
这种方式被称之为原根。
使用单位根时我们会进行大量的浮点计算,这不光让程序的运行时间大大增加,还会带来很大的进度误差。而原根则没有这样的问题。
除此之外\(NTT\)还解决了多项式乘法带模数的情况。

原根

\(a,p\)互素,\(p>1\)
对于\(a^n\equiv1(\mod p)\)最小的\(n\),我们称之为 \(a\)\(p\)的阶,记做\(δ_p(a)\)

原根

定义:

\(p\)是正整数,\(a\)是整数,若\(δ_p(a)\)等于\(\phi(a)\),则\(a\)为模\(p\)的 原根。

性质1:如果一个数字\(p\)有原根,那么它有\(\phi (\phi(p))\)个原根。
性质2:模\(p\)有原根的充要条件\(n=2,4,p^\alpha,p^{2\alpha}\)\(p\)为奇素数。
性质3\(p\)为素数,假设\(g\)\(p\)的原根,那么\(g^i \mod p,(i<p)\) 唯一。

除此之外,\(FFT\)中单位根满足的所有性质原根也满足。
所以我们认为\(g^{\frac{p-1}{n}}\)等价于\(e^{\frac{-2\pi i}{n}}\)
\(NTT\)中,\(p\)通常取998244353。原根为3。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
#include <stack>
using namespace std;
#define ll long long

const int mod = 998244353, G = 3, Gi = 332748118;//这里的Gi是G的除法逆元
const ll N = 5e6+50;
ll n,m;
ll limit = 1;//二进制位数
ll L,R[N];//二进制位数、二进制翻转数组

ll f_pow(ll a,ll b){
    ll res = 1;
    while(b){
        if(b&1)res = res*a%mod;
        a = a*a%mod;
        b>>=1;
    }
    return res%mod;
}
ll a[N],b[N];
void ntt(ll *A,ll type){
    for(ll i = 0;i < limit;i++)if(i < R[i])swap(A[i],A[R[i]]);
    for(ll mid = 1;mid < limit;mid<<=1){
        ll wn = f_pow(G,(mod-1)/(mid<<1));//原根
        if(type == -1) wn = f_pow(wn,mod-2);
        for(ll len = mid<<1,pos = 0;pos < limit;pos+=len){
            ll w = 1;
            for(ll k = 0;k<mid;k++,w = w*wn%mod){
                //原根的操作与单位根类似
                ll x = A[pos+k],y = w*A[pos+k+mid]%mod;
                A[pos+k] = (x+y)%mod;
                A[pos+k+mid] = (x-y+mod)%mod;
            }
        }
    }
    if(type == 1)return ;
    //依然是除n,但是这里需要求逆元
    ll inv_lim = f_pow(limit,mod-2);
    for(ll i = 0;i < limit;i++) A[i] = A[i]*inv_lim%mod;
}
string p,q;
stack<ll> st;
int main() {
    cin>>p>>q;n = p.length()-1,m = q.length()-1;
    for(ll i = n;i >= 0;i--)a[i] = p[i]-'0';
    for(ll i = m;i >= 0;i--)b[i] = q[i]-'0';
    while(limit <= n+m)limit<<=1,L++;//长度
    for(ll i = 0;i < limit;i++){
        R[i] = (R[i>>1]>>1) | ((i&1)<<(L-1));
        //在原序列中i与i/2的关系是:i是i/2的左移
        //那么反转之后就需要右移,同时处理尾数
    }
    ntt(a,1);
    ntt(b,1);
    for(ll i = 0;i <= limit;i++)a[i] = a[i]*b[i];
    ntt(a,-1);//逆变换

    for(ll i = n+m;i > 0;i--){
        ll tmp = a[i];
        //cout<<a[i]<<endl;
        st.push(tmp%10);
        a[i-1] += tmp/10;
    }
    cout<<a[0];
    while(!st.empty()){
        cout<<st.top();
        st.pop();
    }
    return 0;
}

模板:

namespace NTT{
    const int P[3] = {469762049, 998244353, 167772161},//模数
    G = 3, //原根
    Gi[3] = {P[0] / G + 1, P[1] / G + 1, P[2] / G + 1};//逆元
    int R[N];
    inline int f_pow(int x, int y) {
        int ans = 1;
        while(y) {
            if(y & 1) ans = 1ll*ans * x % P[1];
            y >>= 1; x = 1ll*x * x % P[1];
        }
        return ans%P[1];
    }
    inline void reduce(int &x){//看到有人这么写但是没有变快
        x+=x>>31&P[1];
    }
    inline void ntt(int A[],int limit,int type){
        for(int i = 0;i < limit;i++)if(i < R[i])std::swap(A[i],A[R[i]]);
        for(int mid = 1;mid < limit;mid<<=1){
            int wn = f_pow(type == 1?G:Gi[1],(P[1]-1)/(mid<<1));//原根
            // if(type == -1) wn = f_pow(wn,mod-2);
            for(int len = mid<<1,pos = 0;pos < limit;pos+=len){
                int w = 1;
                for(int k = 0;k<mid;k++,w = 1ll*w*wn%P[1]){
                    //原根的操作与单位根类似
                    int x = A[pos+k],y = 1ll*w*A[pos+k+mid]%P[1];
                    A[pos+k] = (x+y)%P[1];
                    A[pos+k+mid] = (x-y+P[1])%P[1];
                }
            }
        }
        if(type == 1)return ;
        ll inv_lim = f_pow(limit,P[1]-2);
        for(ll i = 0;i < limit;i++) A[i] = 1ll*A[i]*inv_lim%P[1];
    }
    inline void mul(int a[],int n,int b[],int m){
        int limit = 1, L = 0;
        while (limit <= n+m)limit <<= 1, L++;
        for (int i = 0; i < limit; i++)R[i] = R[i >> 1] >> 1 | ((i & 1) << (L - 1));
        ntt(a, limit, 1);ntt(b, limit, 1);
        for (int i = 0; i <= limit; i++) a[i] = 1ll*a[i] * b[i] % P[1];
        ntt(a, limit, -1);
    }
}
using NTT::mul;

推荐阅读