知识点: 广义 SAM
可以阅读这篇文章:「笔记」广义后缀自动机
原题面 Luogu
题意简述
求 \(n\) 个字符串 \(S_1\sim S_n\)的本质不同子串个数。
\(1\le n\le 4\times 10^5, 1\le \sum|S_i|\le 10^6\)。
分析题意
这里只说一下 Insert 函数与普通 SAM 不同的部分:
if (ch[last_][c_]) {
int p = last_, q = ch[p][c_];
if (len[p] + 1 == len[q]) return q;
int newq = ++ node_num;
memcpy(ch[newq], ch[q], sizeof(ch[q]));
len[newq] = len[p] + 1;
link[newq] = link[q];
link[q] = newq;
for (; p && ch[p][c_] == q; p = link[p]) ch[p][c_] = newq;
return newq;
}
如果存在 ch[last_][c_]!=0
,那么 last_
和 ch[last_][c_]
都是插入 上一个串 时被添加的节点,需要考虑信息的合并。
设 p = last_, q = ch[p][c_]
。
如果 len[p] + 1 == len[q]
,说明在 p
之后插入新字符的得到的状态,在 SAM 中已经有一个与之完全相同的状态 q
了,可以直接对应过去。
此时状态 q
内含有多个串的 \(\operatorname{endpos}\) 信息。
否则,说明状态 q
内包含了新状态的信息,也包含着原串其他信息,考虑分裂状态 q
,产生新状态 newq
,来储存新串新状态的信息。
最后返回代表新串新状态的 newq
作为新的 last_
。
代码实现
//知识点:SAM
/*
By:Luckyblock
试了试变量写法,挺清爽的。
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define ll long long
const int kMaxn = 1e6 + 10;
const int kMaxm = 26;
//=============================================================
ll ans;
char S[kMaxn];
int node_num = 1, ch[kMaxn << 1][kMaxm], len[kMaxn <<1], link[kMaxn << 1];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
int Insert(int c_, int last_) {
if (ch[last_][c_]) {
int p = last_, q = ch[p][c_];
if (len[p] + 1 == len[q]) return q;
int newq = ++ node_num;
memcpy(ch[newq], ch[q], sizeof(ch[q]));
len[newq] = len[p] + 1;
link[newq] = link[q];
link[q] = newq;
for (; p && ch[p][c_] == q; p = link[p]) ch[p][c_] = newq;
return newq;
}
int p = last_, now = ++ node_num;
len[now] = len[p] + 1;
for (; p && ! ch[p][c_]; p = link[p]) ch[p][c_] = now;
if (! p) {link[now] = 1; return now;}
int q = ch[p][c_];
if (len[q] == len[p] + 1) {link[now] = q; return now;}
int newq = ++ node_num;
memcpy(ch[newq], ch[q], sizeof(ch[q]));
link[newq] = link[q], len[newq] = len[p] + 1;
link[q] = link[now] = newq;
for (; p && ch[p][c_] == q; p = link[p]) ch[p][c_] = newq;
return now;
}
//=============================================================
int main() {
int T = read();
while (T --) {
scanf("%s", S + 1);
int n = strlen(S + 1), last = 1;
for (int i = 1; i <= n; ++ i) {
last = Insert(S[i] - 'a', last);
}
}
for (int i = 2; i <= node_num; ++ i) {
ans += len[i] - len[link[i]];
}
printf("%lld\n", ans);
return 0;
}