首页 > 技术文章 > [十二省联考2019]字符串问题

Fermat 2020-08-01 11:43 原文

现有一个字符串 \(S\)

Tiffany 将从中划出 \(n_a\) 个子串作为 \(A\) 类串,第 \(i\) 个(\(1 \leqslant i \leqslant n_a\))为 \(A_i = S(la_i, ra_i)\)

类似地,Yazid 将划出 \(n_b\) 个子串作为 \(B\) 类串,第 \(i\) 个(\(1 \leqslant i \leqslant n_b\))为 \(B_i = S(lb_i, rb_i)\)

现额外给定 \(m\) 组支配关系,每组支配关系 \((x, y)\) 描述了第 \(x\)\(A\) 类串支配\(y\)\(B\) 类串。

求一个长度最大的目标串 \(T\),使得存在一个串 \(T\) 的分割 \(T = t_1+t_2+· · ·+t_k\)\(k \geqslant 0\))满足:

  • 分割中的每个串 \(t_i\) 均为 \(A\) 类串:即存在一个与其相等的 \(A\) 类串,不妨假设其为 \(t_i = A_{id_i}\)
  • 对于分割中所有相邻的串 \(t_i, t_{i+1}\)\(1 \leqslant i < k\)),都有存在一个\(A_{id_i}\) 支配的 \(B\) 类串,使得该 \(B\) 类串为 \(t_{i+1}\) 的前缀。

方便起见,你只需要输出这个最大的长度即可。

特别地,如果存在无限长的目标串(即对于任意一个正整数 \(n\),都存在一个满足限制的长度超过 \(n\) 的串),请输出 \(-1\)

首先这个问题可以转化,每个A向可支配的B连边,B向符合条件的A(是A的前缀)连边,然后topsort,bfs,如果有环一定是-1,否则就是bfs出的最长路
那么连边中的难点在与B如何向那些A连边。考虑先对A,B的每个子串在SAM上找到对应的结点位置,然后在同一个结点中按长度从大到小为第一关键字,是A类串为第二关键字。
然后在结点中倒序循环,当前长度最长的B向其他点连边。父亲结点中的最长的B结点向当前点连边。当前结点向所有小于等于当前结点最小B节点的点连边。(听着很麻烦,实现并不复杂)
代码是抄的

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
const int N = 8e5 + 100;
typedef long long ll;
char s[N];
int n, na, nb, cnt, tot, last, sz, m;
int ch[N][26], fa[N], f[N][26];
int in[N], a[N], b[N], len[N], pos[N];
int lst[N], h[N], isa[N];
ll dis[N];
struct edge{
    int to,next;
}e[N << 1];
vector<int>g[N];

void add(int u, int v) {
    e[++cnt] = (edge) {v, h[u]}; h[u]= cnt;
    in[v]++;
}

void insert(int x) {
    int p = last, np = ++tot; last = np;
    len[np] = len[p] + 1;
    for(; p && !ch[p][x]; p = fa[p]) ch[p][x] = np;
    if (!p) return void(fa[np] = 1);
    int q = ch[p][x];
    if (len[q] == len[p] + 1) return void(fa[np] = q);
    int nq = ++tot;
    fa[nq] = fa[q];
    len[nq] = len[p] + 1;
    memcpy(ch[nq], ch[q], sizeof(ch[q]));
    fa[q] = fa[np] = nq;
    for (;p && ch[p][x] == q; p = fa[p]) ch[p][x] = nq;
    return ;
}

void find(int o) {
    int l, r ;
    scanf("%d %d", &l, &r);
    r = r - l + 1;l = pos[l];
    for (int i = 19; i >= 0; i--) 
        if (f[l][i] && len[f[l][i]] >= r) l = f[l][i];
    isa[++sz] = o;
    len[sz] = r;
    g[l].push_back(sz);
}

bool cmp(int x,int y) {
    return (len[x] > len[y]) || (len[x] == len[y] && isa[x] > isa[y]);
}



void solve() {
    scanf("%s", s + 1);
    last = tot = 1;
    n = strlen(s + 1);
    for (int i = n; i; i--) insert(s[i] - 'a'), pos[i] = last;
    for (int i = 1; i <= tot; i++) f[i][0] = fa[i];
    for (int i = 1;i <= 19; i++) 
        for (int j = 1; j <= tot; j++) 
            f[j][i] = f[f[j][i-1]][i-1];
    scanf("%d", &na);sz= tot;
    for (int i = 1;i <= na; i++) find(1), a[i] = sz;
    scanf("%d", &nb); 
    for (int i = 1; i<= nb; i++) find(0), b[i] = sz;
    for (int i = 1; i <= tot; i++) sort(g[i].begin(), g[i].end(), cmp);
    for (int i = 1; i <= tot; i++) {
        int last = i;
        for (int j = (int)g[i].size() - 1; j >= 0; j--) {
            int now = g[i][j];
            add(last, now);
            if (!isa[now]) last = now;
        }
        lst[i] = last;
    }
    for (int i = 2; i <= tot; i++) add(lst[fa[i]], i);
    for (int i = 1; i <= sz; i++) if (!isa[i]) len[i] = 0;
    scanf("%d", &m);
    for (int i = 1; i <= m; i++) {
        int x, y;
        scanf("%d %d", &x, &y);
        add(a[x], b[y]);
    }
    bool flag = 0;
    ll ans = 0;
    queue<int> q;
    for (int i = 1; i <= sz; i++) if (!in[i]) q.push(i);
    while (!q.empty()) {
        int j = q.front();
        q.pop();
        ans = max(ans, dis[j] + len[j]);
        for (int i = h[j]; i; i = e[i].next) {
            int v = e[i].to;
            dis[v] = max(dis[v], dis[j] + len[j]);
            in[v]--;
            if (!in[v]) q.push(v);
        }
    }
    for (int i = 1; i <= sz; i++) if (in[i]) flag = 1;
    if (flag) puts("-1");
    else printf("%lld\n", ans);
    for (int i = 1; i <= tot; i++) fa[i] = 0, memset(ch[i], 0, sizeof(ch[i]));
    for (int i = 1; i <= sz; i++) in[i] = h[i] = dis[i] = len[i] = isa[i] =  0, g[i] = vector<int>();
    last = sz = tot = cnt = 0;
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) solve();
}

推荐阅读