首页 > 技术文章 > 「算法笔记」多项式求逆

maoyiting 2021-03-13 09:36 原文

一、基本思路

先给出多项式求逆的定义:

多项式求逆:给定一个 \(n-1\) 次多项式 \(A(x)\),求一个多项式 \(B(x)\) 满足:\(A(x)B(x)\equiv 1\pmod {x^n}\)。模 \(x^n\) 即只考虑多项式的 \(x^0,x^1,\cdots,x^{n-1}\) 项。

考虑倍增。假设已经求出了多项式 \(A(x)\) 在模 \(x^{\lceil \tfrac{n}{2}\rceil}\) 意义下的逆元 \(B'(x)\),那么有:

\[A(x)B'(x) \equiv 1\pmod{x^{\lceil \frac{n}{2}\rceil}} \]

因为 \(A(x)B(x)\equiv 1\pmod {x^n}\),所以显然也有:\(A(x)B(x)\equiv 1\pmod{x^{\lceil \tfrac{n}{2}\rceil}}\)

将两式相减得到:\(A(x)(B(x)-B'(x))\equiv 0\pmod{x^{\lceil\tfrac{n}{2}\rceil}}\)

由于 \(A(x)\nmid x^{\lceil \tfrac{n}{2}\rceil}\),那么有:\(B(x)-B'(x)\equiv 0\pmod{x^{\lceil\tfrac{n}{2}\rceil}}\)

两边同时平方得:\(B(x)^2-2B(x)B'(x)+B'(x)^2\equiv 0\pmod {x^n}\)

移项得:\(B(x)^2\equiv 2B(x)B'(x)-B'(x)^2\pmod {x^n}\)

在两边同时乘上 \(A(x)\) 得到:

\[A(x)B(x)^2\equiv 2A(x)B(x)B'(x)-A(x)B'(x)^2\pmod {x^n} \]

通过逆元的定义 \(A(x)B(x)\equiv 1 \pmod {x^n}\) 可以化简为:

\[B(x)\equiv 2B'(x)-A(x)B'(x)^2\pmod {x^n} \]

二、具体实现

可以 递归 求解,递归边界为 \(n=1\),此时答案为 常数项的逆元。

也可以 迭代 实现(枚举迭代长度),常数较小。

时间复杂度 \(T(n)=T(\tfrac{n}{2})+\mathcal O(n\log n)=\mathcal O(n\log n)\)

//Luogu P4238
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=3e6+5,mod=998244353;
int n,f[N],a[N],b[N],res[N],tmp[N],len,r[N],inv;
int mul(int x,int n,int mod){
    int ans=mod!=1;
    for(x%=mod;n;n>>=1,x=x*x%mod)
        if(n&1) ans=ans*x%mod;
    return ans;
}
void NTT(int a[N],int n,int opt){    //opt=1/-1: DFT/IDFT
    for(int i=0;i<n;i++)
        if(i<r[i]) swap(a[i],a[r[i]]);
    for(int k=2;k<=n;k<<=1){
        int m=k>>1,x=mul(3,(mod-1)/k,mod),w=1,v; 
        if(opt==-1) x=mul(x,mod-2,mod);
        for(int i=0;i<n;i+=k,w=1)
            for(int j=i;j<i+m;j++) v=w*a[j+m]%mod,a[j+m]=(a[j]-v+mod)%mod,a[j]=(a[j]+v)%mod,w=w*x%mod;
    }
    if(opt==-1){
        inv=mul(len,mod-2,mod);
        for(int i=0;i<n;i++) a[i]=a[i]*inv%mod;
    }
} 
void solve(int A[N],int B[N],int n,int m,int res[N]){
    for(len=1;len<=n+m;len<<=1); 
    for(int i=0;i<len;i++)
        r[i]=(r[i>>1]>>1)|((i&1)?len>>1:0),a[i]=b[i]=0;
    for(int i=0;i<n;i++) a[i]=A[i];
    for(int i=0;i<m;i++) b[i]=B[i];
    NTT(a,len,1),NTT(b,len,1);
    for(int i=0;i<len;i++) a[i]=((2*b[i]%mod-a[i]*b[i]%mod*b[i]%mod)%mod+mod)%mod;
    NTT(a,len,-1);
    for(int i=0;i<n;i++) res[i]=a[i];
}
void invp(int a[N],int n,int res[N]){
    if(n==1){res[0]=mul(a[0],mod-2,mod);return ;}
    invp(a,(n+1)/2,res);
    solve(a,res,n,n,res);    //注意这里写 solve(a,res,n,(n+1)/2,res) 会出错。原因:B(x)≡2B′(x)−A(x)B′(x)^2 这个地方,最高次数可以达到 2n 次,如果 solve(a,res,n,(n+1)/2,res) 这样算出来的点值个数可能不到 2n 个,所以没有办法 IDFT 得到正确的多项式。
}
signed main(){
    scanf("%lld",&n);
    for(int i=0;i<n;i++)
        scanf("%lld",&f[i]);
    invp(f,n,res);
    for(int i=0;i<n;i++)
        printf("%lld%c",res[i],i==n-1?'\n':' ');
    return 0;
}

 2022.1.28 重学写了一遍迭代的板子:

//Luogu P4238
#include<bits/stdc++.h>
using namespace std;
const int N=(1<<18)+5,mod=998244353;
int n,a[N],b[N],len,r[N],A[N],B[N],p[N],q[N];
int qpow(int x,int n){
    int ans=1;
    for(;n;n>>=1,x=1ll*x*x%mod) if(n&1) ans=1ll*ans*x%mod;
    return ans;
}
void NTT(int *a,int n,int op){
    for(int i=0;i<n;i++) if(i<r[i]) swap(a[i],a[r[i]]);
    for(int k=2;k<=n;k<<=1){
        int m=k>>1,v=qpow(~op?3:qpow(3,mod-2),(mod-1)/k),w=1,x,y; 
        for(int i=0;i<n;i+=k,w=1)
            for(int j=i;j<i+m;j++)
                x=a[j],y=1ll*w*a[j+m]%mod,a[j]=(x+y)%mod,a[j+m]=(x-y+mod)%mod,w=1ll*w*v%mod;
    }
    if(op==-1){
        int inv=qpow(n,mod-2);
        for(int i=0;i<n;i++) a[i]=1ll*a[i]*inv%mod;
    }
} 
void mul(int *a,int *b,int len){
    for(int i=0;i<len;i++)
        p[i]=a[i],q[i]=b[i],r[i]=r[i>>1]>>1|(i&1?len>>1:0);
    NTT(p,len,1),NTT(q,len,1);
    for(int i=0;i<len;i++) a[i]=1ll*p[i]*q[i]%mod;
    NTT(a,len,-1);
}
void inv(int *a,int *b,int len){
    for(int i=0;i<len;i++) b[i]=0;
    b[0]=qpow(a[0],mod-2);
    for(int i=2;i<=len;i<<=1){
        for(int j=0;j<(i<<1);j++) A[j]=j<i?a[j]:0,B[j]=j<(i>>1)?b[j]:0; 
        //A(x)B'(x)^2 长度最大为 i*2
        mul(B,B,i),mul(A,B,i<<1);
        for(int j=0;j<i;j++) b[j]=(2ll*b[j]%mod-A[j]+mod)%mod;
    }
}
signed main(){
    scanf("%d",&n);
    for(int i=0;i<n;i++) scanf("%d",&a[i]);
    for(len=1;len<=n;len<<=1);
    inv(a,b,len);
    for(int i=0;i<n;i++) printf("%d ",b[i]);
    return 0;
}

 

推荐阅读