首页 > 技术文章 > Codeforces 204E Little Elephant and Strings 后缀数组 + 并查集

CJLHY 2019-07-04 15:45 原文

先对所有字符串拼接后建后缀数组，再按 lcp 从大到小用并查集合并后缀数组中相邻的后缀，对包含至少 k 个不同字符串的区间用差分数组累加贡献，最后求前缀和得到每个字符串的答案。

#include<bits/stdc++.h>
// Shorthand aliases for common types and pair helpers.
#define LL long long
#define LD long double
#define ull unsigned long long
#define fi first
#define se second
#define mk make_pair
#define PLL pair<LL, LL>
#define PLI pair<LL, int>
#define PII pair<int, int>
// Container size as an int. The argument is parenthesized so that expression
// arguments such as SZ(*ptr) expand to (*ptr).size() rather than *(ptr.size()).
#define SZ(x) (static_cast<int>((x).size()))
// Expands to the begin/end iterator pair of a container, for use as two arguments.
#define ALL(x) (x).begin(), (x).end()
// Turns off C stdio synchronisation and unties cin from cout.
// Note: expands to two statements, so do not use it as an unbraced if/else body.
#define fio ios::sync_with_stdio(false); cin.tie(0);

using namespace std;

// Capacity used for every fixed-size global array declared in this file.
const int N = 2e5 + 7;
// Large sentinel values, a tolerance and pi; none of these four are referenced
// by the visible code.
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3f;
// Modulus applied by the add() / sub() helpers.
const int mod = 998244353;
const double eps = 1e-8;
const double PI = acos(-1);

// Adds b to a in place, then subtracts mod once if the result reached mod.
template<class T, class S>
inline void add(T &a, S b) {
    a += b;
    if(a >= mod) {
        a -= mod;
    }
}
// Subtracts b from a in place, then adds mod once if the result went negative.
template<class T, class S>
inline void sub(T &a, S b) {
    a -= b;
    if(a < 0) {
        a += mod;
    }
}
// Raises a to b when b is strictly larger; returns whether a was changed.
template<class T, class S>
inline bool chkmax(T &a, S b) {
    if(a < b) {
        a = b;
        return true;
    }
    return false;
}
// Lowers a to b when b is strictly smaller; returns whether a was changed.
template<class T, class S>
inline bool chkmin(T &a, S b) {
    if(a > b) {
        a = b;
        return true;
    }
    return false;
}

// Mersenne Twister engine seeded from the steady clock's current tick count.
// Not referenced by the visible code.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// Suffix-array state shared between main() and buildSa():
//   r    - the concatenated symbol sequence filled by main()
//   sa   - sa[i] is the start position of the i-th suffix in sorted order
//   _t, _t2 - scratch buffers that buildSa() uses through its x / y pointers
//   c    - counting-sort buckets
//   rk   - rk[p] is the index in sa of the suffix starting at p
//   lcp  - lcp[i] is the common-prefix length of suffixes sa[i - 1] and sa[i]
//   san  - number of symbols written into r so far
int r[N], sa[N], _t[N], _t2[N], c[N], rk[N], lcp[N], san;
// Next unused symbol value; main() takes separators from it and passes the
// final value to buildSa() as the exclusive upper bound on symbol values.
int maxc = 'z' + 1;

// Builds the suffix array (sa), rank array (rk) and LCP array (lcp) for the
// sequence r[0..n-1] using prefix doubling with counting sort.
//
// Parameters:
//   r - symbol sequence (this parameter shadows the global array of the same name)
//   n - number of symbols to process
//   m - exclusive upper bound on the symbol values, used as the bucket count
//
// Results are written to the globals sa, rk and lcp; _t, _t2 and c are scratch.
// NOTE(review): the rank comparison at offset +k and the LCP extension loop are
// not bounds-checked. main() appends a single 0 symbol, smaller than every other
// symbol, as the last element; this code appears to rely on that - confirm
// before calling it with other input.
void buildSa(int *r, int n, int m) {
    // The k declared here stays 0 during the doubling loop (which declares its
    // own k) and is the running match length of the LCP pass at the bottom.
    int i, j = 0, k = 0, *x = _t, *y = _t2;
    // Counting sort of all suffixes by their first symbol; x holds the initial ranks.
    for(i = 0; i < m; i++) c[i] = 0;
    for(i = 0; i < n; i++) c[x[i] = r[i]]++;
    for(i = 1; i < m; i++) c[i] += c[i - 1];
    for(i = n - 1; i >= 0; i--) sa[--c[x[i]]] = i;
    // Doubling: each round sorts by the pair (rank of first k symbols, rank of next k).
    for(int k = 1; k <= n; k <<= 1) {
        int p = 0;
        // y lists suffixes ordered by the second key: those starting at n - k or
        // later have no second half and come first, the rest follow in sa order.
        for(i = n - k; i < n; i++) y[p++] = i;
        for(i = 0; i < n; i++) if(sa[i] >= k) y[p++] = sa[i] - k;
        // Stable counting sort of y by the first key x, filling sa from the back.
        for(i = 0; i < m; i++) c[i] = 0;
        for(i = 0; i < n; i++) c[x[y[i]]]++;
        for(i = 1; i < m; i++) c[i] += c[i - 1];
        for(i = n - 1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i];
        // After the swap y holds the previous ranks and x receives the new ones.
        swap(x, y);
        p = 1; x[sa[0]] = 0;
        for(int i = 1; i < n; i++) {
            // Equal rank only if both halves tie; && short-circuits, so the +k
            // lookups are evaluated only when the first halves are equal.
            if(y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k])
                x[sa[i]] = p - 1;
            else x[sa[i]] = p++;
        }
        // p counts the distinct ranks; once it reaches n the order is final.
        if(p >= n) break;
        m = p;
     }
     // rk is filled for sa[1..n-1] only; rk[sa[0]] keeps whatever value it had
     // (the global array starts zero-initialised).
     for(i = 1; i < n; i++) rk[sa[i]] = i;
     // LCP pass over text positions 0..n-2: compare each suffix with its
     // predecessor in sa, starting from the previous match length minus one.
     for(i = 0; i < n - 1; i++) {
        if(k) k--;
        j = sa[rk[i] - 1];
        while(r[i + k] == r[j + k]) k++;
        lcp[rk[i]] = k;
     }
}

// Disjoint-set state over suffix-array indices: fa is the parent link, and for
// a root, dl / dr are the smallest and largest index merged into its set.
int fa[N], dl[N], dr[N];

// n, k    - the two integers read first by main(); n is the number of strings
// maxH    - largest value found in lcp[1..san]
// belong  - belong[p] is the 1-based index of the string owning text position p
//           (left at 0 for separator and terminator positions)
// to      - to[i] is the smallest suffix-array index such that positions
//           i..to[i] cover k distinct strings; 0 when main() did not set it
// cnt     - per-string occurrence counts inside the sliding window
int n, k, maxH, belong[N], to[N], cnt[N];
// ans[i] is the value printed for string i; tmp is a difference array over
// suffix-array indices that main() later turns into prefix sums.
LL ans[N], tmp[N];
// Input buffer for one string.
char s[N];

// V[h] lists the suffix-array indices i with lcp[i] == h.
vector<int> V[N];
// Scratch list of set roots touched while processing one lcp value.
vector<int> ID;

// Returns the representative of the set containing x and re-points every node
// visited on the way directly at that representative.
int getRoot(int x) {
    // First pass: walk up the parent links to the representative.
    int root = x;
    while(fa[root] != root) {
        root = fa[root];
    }
    // Second pass: compress the path that was just walked.
    while(x != root) {
        int parent = fa[x];
        fa[x] = root;
        x = parent;
    }
    return root;
}

// Attaches the set containing v under the root of the set containing u,
// widens that root's [dl, dr] range to cover both sets, and returns the root.
int Merge(int u, int v) {
    const int keep = getRoot(u);
    const int absorbed = getRoot(v);
    fa[absorbed] = keep;
    if(dl[absorbed] < dl[keep]) {
        dl[keep] = dl[absorbed];
    }
    if(dr[absorbed] > dr[keep]) {
        dr[keep] = dr[absorbed];
    }
    return keep;
}

// Debug helper: prints the symbols of the suffix at suffix-array index x
// (positions sa[x] through san inclusive), then sa[x] + 3 spaces, then the
// sa and lcp values for that index.
void printSuf(int x) {
    const int start = sa[x];
    int pos = start;
    while(pos <= san) {
        putchar(static_cast<char>(r[pos]));
        ++pos;
    }
    int pad = start + 3;
    while(pad-- > 0) {
        putchar(' ');
    }
    printf("sa: %d  lcp: %d\n", sa[x], lcp[x]);
}

// Reads n strings and the threshold k, then prints one number per string.
// For k == 1 the number is len * (len + 1) / 2. Otherwise the strings are
// concatenated, a suffix array is built, and suffix-array intervals that share
// a common prefix and cover at least k distinct strings add their prefix-length
// range to every suffix inside them; each string sums the totals of its suffixes.
int main() {
    scanf("%d%d", &n, &k);
    for(int i = 1; i <= n; i++) {
        // Separate consecutive strings with a fresh symbol above 'z'; belong[]
        // is not written for it, so it keeps the value 0.
        if(i > 1) r[san++] = maxc++;
        scanf("%s", s);
        int len = strlen(s);
        for(int j = 0; s[j]; j++) {
            r[san] = s[j];
            belong[san] = i;
            san++;
        }
        // With k == 1 the answer is the number of (start, end) pairs of the string.
        if(k == 1) ans[i] = 1LL * len * (len + 1) / 2;
    }

    if(k == 1) {
        for(int i = 1; i <= n; i++) {
            printf("%lld ", ans[i]);
        }
        puts("");
        return 0;
    }

    // Terminate the sequence with 0, which is smaller than every other symbol.
    r[san] = 0;
    buildSa(r, san + 1, maxc);

    // Bucket suffix-array indices by their lcp value and record the largest one.
    for(int i = 1; i <= san; i++) {
        V[lcp[i]].push_back(i);
        chkmax(maxH, lcp[i]);
    }

    // Number of distinct strings currently inside the window [i, j).
    int now = 0;

    // Two pointers over suffix-array order: for each left end i, advance j until
    // the window holds k distinct strings, then record the last index in to[i].
    for(int i = 1, j = 1; i <= san; i++) {
        // Keep the right end from falling behind the left end.
        j = max(j, i);
        while(now < k && j <= san) {
            // Positions with belong == 0 (separators) are not counted.
            if(belong[sa[j]]) {
                if(!cnt[belong[sa[j]]]) now++;
                cnt[belong[sa[j]]]++;
            }
            j++;
        }
        // The right end ran out: no later left end can succeed; their to[] stays 0.
        if(now < k) break;
        to[i] = j - 1;
        // Remove position i from the window before moving the left end.
        if(belong[sa[i]]) {
            if(cnt[belong[sa[i]]] == 1) now--;
            cnt[belong[sa[i]]]--;
        }
    }

    // Every suffix-array index starts as its own set covering just itself.
    for(int i = 1; i <= san; i++) {
        fa[i] = dl[i] = dr[i] = i;
    }

    // Sweep lcp values from largest to smallest, joining neighbours whose lcp
    // equals the current value i.
    for(int i = maxH; i > 0; i--) {
        ID.clear();
        for(auto &t : V[i]) {
            ID.push_back(Merge(t - 1, t));
        }
        // Roots collected earlier in this round may have been absorbed by later
        // merges, so resolve them again and drop duplicates.
        for(auto &t : ID) {
            t = getRoot(t);
        }
        sort(ALL(ID));
        ID.erase(unique(ALL(ID)), ID.end());
        for(auto &t : ID) {
            // Skip intervals that do not reach k distinct strings.
            if(!to[dl[t]] || dr[t] < to[dl[t]]) continue;
            // x is the larger of the lcp values at the two interval boundaries;
            // the contribution added below is i - x.
            int x = lcp[dl[t]];
            if(dr[t] + 1 <= san) chkmax(x, lcp[dr[t] + 1]);
            // Range-add i - x over [dl, dr] using the difference array.
            tmp[dl[t]] += i - x;
            tmp[dr[t] + 1] -= i - x;
        }
    }

    // Prefix-sum the difference array and credit each suffix's total to its
    // string; separator positions accumulate into ans[0], which is not printed.
    for(int i = 1; i <= san; i++) {
        tmp[i] += tmp[i - 1];
        ans[belong[sa[i]]] += tmp[i];
    }

    for(int i = 1; i <= n; i++) {
        printf("%lld ", ans[i]);
    }
    puts("");
    return 0;
}

/*
*/

 

推荐阅读