最近公共祖先(LCA)
首先是最近公共祖先的概念 (什么是最近公共祖先? ): 在一棵没有环的树上,每个节点肯定有其父亲节点和祖先节点,而最近公共祖先,就是两个节点在这棵树上深度最大的公共的祖先节点。 换句话说,就是两个点在这棵树上距离最近的公共祖先节点
朴素算法
求最近公共祖先的方法有很多,想最简单的朴素算法,就是遍历,从两个点开始,每次让深度大的先往上方跳,那么在第一次遇见的点就是他们的最近公共祖先
前提是需要先dfs遍历一遍树,求出深度
倍增算法
由于我不会Tarjan 算法倍增算法通过预处理fa数组,可以快速的在树上向上方跳,建立fa数组fa[u][i]表示u的第2^i个祖先
,对于求两点之间距离的问题还需要dist[u][i]表示u到第2^i个祖先的距离
代码对应 How far away ?
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
typedef pair<int, int> PII;
typedef long long ll;
inline int read(void)//常见的读入
{
register int x = 0;
register short sgn = 1;
register char c = getchar();
while (c < 48 || 57 < c)
{
if (c == 45)
sgn = 0;
c = getchar();
}
while (47 < c && c < 58)
{
x = (x << 3) + (x << 1) + c - 48;
c = getchar();
}
return sgn ? x : -x;
}
inline void write(ll x)//没有特点的输出
{
if (x < 0)
putchar('-'), x = -x;
if (x > 9)
write(x / 10);
putchar(x % 10 + '0');
}
const int N = 4e4 + 10, M = 2 * N;
int T, n, m;
int h[N], e[M], ne[M], w[M], idx;//邻接表存图
int fa[N][31], dist[N][31], dep[N];//fa[u][i]表示u的第2^i个祖先,dist[u][i]表示u到第2^i个祖先的距离,dep[u]表示u的深度
inline void add (int a, int b, int c) {//加边函数
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
void dfs (int u, int father) {
fa[u][0] = father;//当前点的第2^0个祖先就是我的父亲
dep[u] = dep[fa[u][0]] + 1;//当前点的深度就是父亲的深度 + 1
for (int i = 1;i < 31;i ++) {
//u的第2^i个祖先,就是第2^(i - 1)个祖先的第2^(i - 1)个祖先
fa[u][i] = fa[fa[u][i - 1]][i - 1];
//u到第2^i个祖先的距离就是u到第2^(i - 1)个祖先的距离加上第2^(i - 1)个祖先到他的2^(i - 1)个祖先的距离
dist[u][i] = dist[u][i - 1] + dist[fa[u][i - 1]][i - 1];
}
for (int i = h[u];~i;i = ne[i]) {//遍历子节点
int j = e[i];
if (j == father) continue;//由于是无向图,防止往回搜
dist[j][0] = w[i];//从u到v,那么v到第2^0个祖先,也就是v到父亲的距离,就是这条边的长度
dfs(j, u);
}
}
int lca (int x, int y) {
int res = 0;
if(dep[x] > dep[y]) swap(x, y);//令y为深度大的点
int tmp = dep[y] - dep[x];//得到深度差
for (int j = 0;tmp;j ++, tmp >>= 1) {//二进制递减深度
if(tmp & 1) res += dist[y][j], y = fa[y][j];//二进制累加、递推
}
if (x == y) return res;//如果x == y,证明x和y在同一个小子树上,直接返回结果
//如果不在一个小子树上,继续往上找
for (int i = 30;i >= 0;i --) {//这里从上往下,使得x和y到距离最近公共祖先最近的点
if (fa[x][i] != fa[y][i]) {
res += dist[x][i] + dist[y][i];
x = fa[x][i];
y = fa[y][i];
}
}
res += dist[x][0] + dist[y][0];//最后把从x和y到最近公共祖先的距离也加上
return res;
}
int main()
{
//freopen("in.txt", "r", stdin);
//freopen("out.txt", "w", stdout);
T = read();
while (T --) {
memset(h, -1, sizeof h);
n = read(), m = read();
for (int i = 1;i < n;i ++) {
int a = read(), b = read(), c = read();
add(a, b, c);//无向图,存两条边
add(b, a, c);
}
dfs(1, 0);
while (m --) {
int a = read(), b = read();
write(lca(a, b));
puts("");
}
}
return 0;
}
如果是为了求最近公共祖先的点,就可以把求距离的全部删掉了
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
typedef pair<int, int> PII;
typedef long long ll;
inline int read(void)
{
register int x = 0;
register short sgn = 1;
register char c = getchar();
while (c < 48 || 57 < c)
{
if (c == 45)
sgn = 0;
c = getchar();
}
while (47 < c && c < 58)
{
x = (x << 3) + (x << 1) + c - 48;
c = getchar();
}
return sgn ? x : -x;
}
inline void write(ll x)
{
if (x < 0)
putchar('-'), x = -x;
if (x > 9)
write(x / 10);
putchar(x % 10 + '0');
}
const int N = 4e4 + 10, M = 2 * N;
int T, n, m;
int h[N], e[M], ne[M], idx;//存图
int fa[N][31], dist[N][31], dep[N];//fa[u][i]表示u的第2^i个祖先,dist[u][i]表示u到第2^i个祖先的距离,dep[u]表示u的深度
inline void add (int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs (int u, int father) {
fa[u][0] = father;//当前点的第2^0个祖先就是我的父亲
dep[u] = dep[fa[u][0]] + 1;//当前点的深度就是父亲的深度 + 1
for (int i = 1;i < 31;i ++) {
fa[u][i] = fa[fa[u][i - 1]][i - 1];//u的第2^i个祖先,就是第2^(i - 1)个祖先的第2^(i - 1)个祖先
}
for (int i = h[u];~i;i = ne[i]) {//遍历子节点
int j = e[i];
if (j == father) continue;
dfs(j, u);
}
}
int lca (int x, int y) {
if(dep[x] > dep[y]) swap(x, y);//令y为深度大的点
int tmp = dep[y] - dep[x];//得到深度差
for (int j = 0;tmp;j ++, tmp >>= 1) {//二进制递减深度
if(tmp & 1) y = fa[y][j];
}
if (x == y) return x;
//如果不在一个小子树上,继续往上找
for (int i = 30;i >= 0;i --) {//这里从上往下,使得x和y到距离最近公共祖先最近的点
if (fa[x][i] != fa[y][i]) {
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
int main()
{
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
// T = read();
// while (T --) {
memset(h, -1, sizeof h);
n = read(), m = read();
int root = read();
for (int i = 1;i < n;i ++) {
int a = read(), b = read();
add(a, b);
add(b, a);
}
dfs(root, 0);
while (m --) {
int a = read(), b = read();
write(lca(a, b));
puts("");
}
// }
return 0;
}
等我学习了Tarjan算法我还会回来的