首页 > 技术文章 > 最近公共祖先(LCA)

AKing- 2021-08-20 15:09 原文

最近公共祖先(LCA)

首先是最近公共祖先的概念 (什么是最近公共祖先? ): 在一棵没有环的树上,每个节点肯定有其父亲节点和祖先节点,而最近公共祖先,就是两个节点在这棵树上深度最大的公共的祖先节点。 换句话说,就是两个点在这棵树上距离最近的公共祖先节点

朴素算法

求最近公共祖先的方法有很多,想最简单的朴素算法,就是遍历,从两个点开始,每次让深度大的先往上方跳,那么在第一次遇见的点就是他们的最近公共祖先

前提是需要先dfs遍历一遍树,求出深度

倍增算法

由于我不会Tarjan 算法倍增算法通过预处理fa数组,可以快速的在树上向上方跳,建立fa数组fa[u][i]表示u的第2^i个祖先,对于求两点之间距离的问题还需要dist[u][i]表示u到第2^i个祖先的距离

代码对应 How far away ?

#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;

typedef pair<int, int> PII;
typedef long long ll;

inline int read(void)//常见的读入
{
    register int x = 0;
    register short sgn = 1;
    register char c = getchar();
    while (c < 48 || 57 < c)
    {
        if (c == 45)
            sgn = 0;
        c = getchar();
    }
    while (47 < c && c < 58)
    {
        x = (x << 3) + (x << 1) + c - 48;
        c = getchar();
    }
    return sgn ? x : -x;
}
inline void write(ll x)//没有特点的输出
{
    if (x < 0)
        putchar('-'), x = -x;
    if (x > 9)
        write(x / 10);
    putchar(x % 10 + '0');
}

const int N = 4e4 + 10, M = 2 * N;
int T, n, m;
int h[N], e[M], ne[M], w[M], idx;//邻接表存图
int fa[N][31], dist[N][31], dep[N];//fa[u][i]表示u的第2^i个祖先,dist[u][i]表示u到第2^i个祖先的距离,dep[u]表示u的深度

inline void add (int a, int b, int c) {//加边函数
	e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}

void dfs (int u, int father) {
	fa[u][0] = father;//当前点的第2^0个祖先就是我的父亲
	dep[u] = dep[fa[u][0]] + 1;//当前点的深度就是父亲的深度 + 1

	for (int i = 1;i < 31;i ++) {
        //u的第2^i个祖先,就是第2^(i - 1)个祖先的第2^(i - 1)个祖先
		fa[u][i] = fa[fa[u][i - 1]][i - 1];
        //u到第2^i个祖先的距离就是u到第2^(i - 1)个祖先的距离加上第2^(i - 1)个祖先到他的2^(i - 1)个祖先的距离
		dist[u][i] = dist[u][i - 1] + dist[fa[u][i - 1]][i - 1];
	}

	for (int i = h[u];~i;i = ne[i]) {//遍历子节点
		int j = e[i];
		if (j == father) continue;//由于是无向图,防止往回搜
		dist[j][0] = w[i];//从u到v,那么v到第2^0个祖先,也就是v到父亲的距离,就是这条边的长度
		dfs(j, u);
	}
}

int lca (int x, int y) {
	int res = 0;

	if(dep[x] > dep[y]) swap(x, y);//令y为深度大的点

	int tmp = dep[y] - dep[x];//得到深度差
	for (int j = 0;tmp;j ++, tmp >>= 1) {//二进制递减深度
		if(tmp & 1) res += dist[y][j], y = fa[y][j];//二进制累加、递推
	}

	if (x == y) return res;//如果x == y,证明x和y在同一个小子树上,直接返回结果

	//如果不在一个小子树上,继续往上找
	for (int i = 30;i >= 0;i --) {//这里从上往下,使得x和y到距离最近公共祖先最近的点
		if (fa[x][i] != fa[y][i]) {
			res += dist[x][i] + dist[y][i];
			x = fa[x][i];
			y = fa[y][i];
		}
	}

	res += dist[x][0] + dist[y][0];//最后把从x和y到最近公共祖先的距离也加上

	return res;
}

int main()
{
    //freopen("in.txt", "r", stdin);
    //freopen("out.txt", "w", stdout);
    T = read();
    while (T --) {
    	memset(h, -1, sizeof h);
    	n = read(), m = read();
    	for (int i = 1;i < n;i ++) {
    		int a = read(), b = read(), c = read();
    		add(a, b, c);//无向图,存两条边
    		add(b, a, c);
    	}
    	dfs(1, 0);
    	while (m --) {
    		int a = read(), b = read();
    		write(lca(a, b));
    		puts("");
    	}
    }

    return 0;
}

如果是为了求最近公共祖先的点,就可以把求距离的全部删掉了

代码对应 P3379 【模板】最近公共祖先(LCA)

#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;

typedef pair<int, int> PII;
typedef long long ll;

inline int read(void)
{
    register int x = 0;
    register short sgn = 1;
    register char c = getchar();
    while (c < 48 || 57 < c)
    {
        if (c == 45)
            sgn = 0;
        c = getchar();
    }
    while (47 < c && c < 58)
    {
        x = (x << 3) + (x << 1) + c - 48;
        c = getchar();
    }
    return sgn ? x : -x;
}
inline void write(ll x)
{
    if (x < 0)
        putchar('-'), x = -x;
    if (x > 9)
        write(x / 10);
    putchar(x % 10 + '0');
}

const int N = 4e4 + 10, M = 2 * N;
int T, n, m;
int h[N], e[M], ne[M], idx;//存图
int fa[N][31], dist[N][31], dep[N];//fa[u][i]表示u的第2^i个祖先,dist[u][i]表示u到第2^i个祖先的距离,dep[u]表示u的深度

inline void add (int a, int b) {
	e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

void dfs (int u, int father) {
	fa[u][0] = father;//当前点的第2^0个祖先就是我的父亲
	dep[u] = dep[fa[u][0]] + 1;//当前点的深度就是父亲的深度 + 1

	for (int i = 1;i < 31;i ++) {
		fa[u][i] = fa[fa[u][i - 1]][i - 1];//u的第2^i个祖先,就是第2^(i - 1)个祖先的第2^(i - 1)个祖先
	}

	for (int i = h[u];~i;i = ne[i]) {//遍历子节点
		int j = e[i];
		if (j == father) continue;
		dfs(j, u);
	}
}

int lca (int x, int y) {

	if(dep[x] > dep[y]) swap(x, y);//令y为深度大的点

	int tmp = dep[y] - dep[x];//得到深度差
	for (int j = 0;tmp;j ++, tmp >>= 1) {//二进制递减深度
		if(tmp & 1) y = fa[y][j];
	}

	if (x == y) return x;

	//如果不在一个小子树上,继续往上找
	for (int i = 30;i >= 0;i --) {//这里从上往下,使得x和y到距离最近公共祖先最近的点
		if (fa[x][i] != fa[y][i]) {
			x = fa[x][i];
			y = fa[y][i];
		}
	}

	return fa[x][0];
}

int main()
{
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    // T = read();
    // while (T --) {
    	memset(h, -1, sizeof h);
    	n = read(), m = read();
    	int root = read();
    	for (int i = 1;i < n;i ++) {
    		int a = read(), b = read();
    		add(a, b);
    		add(b, a);
    	}
    	dfs(root, 0);
    	while (m --) {
    		int a = read(), b = read();
    		write(lca(a, b));
    		puts("");
    	}
    // }

    return 0;
}

等我学习了Tarjan算法我还会回来的

推荐阅读