首页 > 技术文章 > CF666E Forensic Examination

GK0328 2020-11-10 16:04 原文

CF666E Forensic Examination

广义\(SAM\)+线段树合并

这里利用到\(SAM\)\(endpos\)集合的性质,我们可以注意到,在\(SAM\)\(parent\)树上,祖先节点的\(endpos\)集合是包含子孙\(endpos\)集合的。

同时,所有子孙节点的\(endpos\)集合两两之间无交集,所有子孙节点的\(endpos\)集合的并集恰好包括了字符串中的每一个位置。利用这个性质,我们可以通过线段树合并直接直接求出所有节点的\(endpos\)集合。

我们阅读本题,包含了两个区间,看起来很吓人的样子,实际上\(S\)的区间是很容易消去的。

首先,对\(T\)建立广义\(SAM\),那么我们考虑用\(S\)去匹配,在\(S\)\(r\)位置,我们在\(SAM\)中匹配到的后缀最长长度是可以计算的,同时我们可以记录我们在\(SAM\)上匹配的位置\(p\)

我们需要找到\(S[l,r]\)在哪个\(endpos\)集合中,我们最终在该节点上统计答案。根据\(SAM\)的性质,这个\(endpos\)集合恰好包括了所有询问串出现的位置,当然这个节点不只代表一个串,但是对于只考虑串的位置来说,这些串是等价的。

如何找到这个位置,显然不能暴力跳\(parent\)树,我们考虑倍增,我们需要从\(p\)开始,找到一个\(len \ge r-l+1\)的深度最浅的位置。

我们利用权值线段树记录所有后缀能够匹配的最深节点,当然不是直接记录\(endpos\)集合,因为我们不需要这个,我们只需要记录这个后缀属于哪个\(T\)串即可。

那么我们在查询时,询问\(T[l,r]\)之间的最大值,同时不断进行线段树合并即可。

细节:判断不存在时输出\(l,0\),有两种情况。

\(1.\)不存在该串。

\(2.\)该串在询问区间中不出现。

特判即可。

\(Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<vector>
#define N 500005
#define M 50005
using namespace std;
struct ask
{
    int id,l,r;
    ask (int Id=0,int L=0,int R=0)
    {
        id=Id,l=L,r=R;
    }
};
struct node
{
    int x,y;
    node (int xx=0,int yy=0)
    {
        x=xx,y=yy;
    }
    bool operator < (const node &A) const
    {
        return (x==A.x)?(y>A.y):(x<A.x);
    }
}t2[M << 5],ans[N];
vector<ask>e[M << 1];
#define IT vector<ask> :: iterator
int m,q,lst=1,tot=1;
int l,r,pl,pr;
int v[N],w[N];
int f[M << 1][21];
int tr[M << 1][26];
int cnt=0,t[M << 5],ls[M << 5],rs[M << 5],tt[M << 1];
int pre[M << 1],len[M << 1];
string s,T[M];
struct edge
{
    int nxt,v;
    edge (int Nxt=0,int V=0)
    {
        nxt=Nxt,v=V;
    }
}E[M << 1];
int ct,fr[M << 1];
void add(int x,int y)
{
    ++ct;
    E[ct]=edge(fr[x],y),fr[x]=ct;
}
void update(int p)
{
    t2[p]=max(t2[ls[p]],t2[rs[p]]);
}
void modify(int &p,int l,int r,int x)
{   
    if (!p)
        p=++cnt;
    if (l==r)
    {
        ++t[p];
        t2[p]=node(t[p],l);
        return;
    }
    int mid=(l+r) >> 1;
    if (x<=mid)
        modify(ls[p],l,mid,x); else
        modify(rs[p],mid+1,r,x);
    update(p);
}
int conbine(int x,int y,int l,int r)
{
    if (!x || !y)
        return x|y;
    if (l==r)
    {
        t[x]+=t[y];
        t2[x]=node(t[x],l);
        return x;
    }
    int mid=(l+r) >> 1;
    ls[x]=conbine(ls[x],ls[y],l,mid);
    rs[x]=conbine(rs[x],rs[y],mid+1,r);
    update(x);
    return x;
}
node calc(int p,int l,int r,int x,int y)
{
    if (!p)
        return t2[0];
    if (l==x && r==y)
        return t2[p];
    int mid=(l+r) >> 1;
    if (y<=mid)
        return calc(ls[p],l,mid,x,y); else
    if (x>mid)
        return calc(rs[p],mid+1,r,x,y); else
        return max(calc(ls[p],l,mid,x,mid),calc(rs[p],mid+1,r,mid+1,y));
}
void ic(int c)
{
    int np,p,q,g;
    if (tr[lst][c])
    {
        int p=lst,q=tr[lst][c];
        if (len[p]+1==len[q])
            lst=q; else
            {
                g=++tot;
                memcpy(tr[g],tr[q],sizeof(tr[q]));
                len[g]=len[p]+1,pre[g]=pre[q];
                for (;p && tr[p][c]==q;p=pre[p])
                    tr[p][c]=g;
                pre[q]=g;
                lst=g;
            }
        return;
    }
    np=++tot;
    p=lst;
    len[np]=len[p]+1;
    for (;p && !tr[p][c];p=pre[p])
        tr[p][c]=np;
    if (!p)
        pre[np]=1; else
        {
            q=tr[p][c];
            if (len[p]+1==len[q])
                pre[np]=q; else
                {
                    g=++tot;
                    memcpy(tr[g],tr[q],sizeof(tr[q]));
                    len[g]=len[p]+1,pre[g]=pre[q];
                    for (;p && tr[p][c]==q;p=pre[p])
                        tr[p][c]=g;
                    pre[np]=pre[q]=g;
                }
        }
    lst=np;
}
void ins(string &s)
{
    lst=1;
    int n=s.length();
    for (int i=0;i<n;++i)
        ic(s[i]-'a');
}
void solve(int u)
{
    for (int i=fr[u];i;i=E[i].nxt)
    {
        int v=E[i].v;
        solve(v);
        tt[u]=conbine(tt[u],tt[v],1,m);
    }
    for (IT it=e[u].begin();it!=e[u].end();++it)
    {
        ans[it->id]=calc(tt[u],1,m,it->l,it->r);
        if (ans[it->id].x==-1000000007)
            ans[it->id]=node(0,it->l);
    }
}
int main()
{
    cin >> s;
    scanf("%d",&m);
    for (int i=1;i<=m;++i)
        cin >> T[i],ins(T[i]);
    int n=s.length();
    int p=1,nl=0;
    for (int i=0;i<n;++i)
    {
        int c=s[i]-'a';
        if (tr[p][c])
            p=tr[p][c],++nl; else
            {
                while (p && !tr[p][c])
                    p=pre[p],nl=len[p];
                if (p)
                    p=tr[p][c],++nl;
            }
        v[i]=p,w[i]=nl;
        if (!p)
            p=1,nl=0;
    }
    for (int i=2;i<=tot;++i)
        f[i][0]=pre[i],add(pre[i],i);
    for (int j=1;j<=18;++j)
        for (int i=1;i<=tot;++i)
            f[i][j]=f[f[i][j-1]][j-1];
    scanf("%d",&q);
    for (int i=1;i<=q;++i)
    {
        scanf("%d%d%d%d",&l,&r,&pl,&pr);
        --pl,--pr;
        int p=v[pr],tl=pr-pl+1;
        if (!p || w[pr]<pr-pl+1)
            ans[i]=node(0,l); else
            {
                for (int j=18;j>=0;--j)
                    if (f[p][j] && len[f[p][j]]>=tl)
                        p=f[p][j];
                e[p].push_back(ask(i,l,r));
            }
    }
    t2[0]=node(-1000000007,1000000007);
    for (int i=1;i<=m;++i)
    {
        int p=1,n=T[i].length();
        for (int j=0;j<n;++j)
            p=tr[p][T[i][j]-'a'],modify(tt[p],1,m,i);
    }
    solve(1);
    for (int i=1;i<=q;++i)
        printf("%d %d\n",ans[i].y,ans[i].x);
    return 0;
}

推荐阅读