首页 > 技术文章 > 树链剖分

kylinbalck 2018-10-30 15:55 原文

树链剖分


树这个结构,本身很优美,但有些涉及区间的问题用树来做就会比较别扭。有一些方便处理区间的数据结构,可以和树结合一下。比如线段树。

定义

重儿子:所有儿子中节点数最多的儿子

轻儿子:除重儿子以外的儿子

重边:链接重儿子的链

轻边:剩余的边

重链:链接重儿子的链

轻链:其余的链

思想

将树按照一条条链剖开,就可以用线段树处理了。

如图,红色为重链,黑色为轻链。

如何维护?

两遍\(dfs\)

//第一遍
int son[100005],sz[100005],dep[100005],f[100005];
//重儿子         子树大小    深度        父亲节点
void dfs(int u,int fa){
    f[u]=fa,sz[u]=1,dep[u]=dep[fa]+1;//记录
    int ma=0;//存重儿子大小
    for(int i=head[u],v;v=a[i].to,i;i=a[i].next){//遍历每条相连的边
        if(v==fa) continue;//如果是父亲,跳过
        dfs(v,u);
        sz[u]+=sz[v];//统计子树大小
        if(sz[v]>ma) son[u]=v,ma=sz[v];//不断更新重儿子
    }
}
//第二遍
int top[100005],pt[100005],dfn[100005],num;
//   链头 / 该节点对应的dfs序 / 该dfs序对应的节点 / dfs序
void dfs(int u,int fa,int tp){//tp是链头
    top[u]=tp,pt[u]=++num,dfn[num]=u;
    if(!son[u]) return;//若没有儿子,就返回
    dfs(son[u],u,tp);//递归搜索重儿子,链头不变
    for(int i=head[u],v;v=a[i].to,i;i=a[i].next){
        if(v==fa||v==son[u]) continue;
        dfs(v,u,v);//搜索轻儿子,链头为它自己
    }
}

两遍\(dfs\)之后,一棵树就被我们抽筋剥骨变成若干条不相干的链了。

然后我们就可以用线段树维护一些乱七八糟的东西了。

提供一道例题

树的统计

题目描述

一棵树上有\(n\)个节点,编号分别为\(1\)\(n\),每个节点都有一个权值\(w\)

我们将以下面的形式来要求你对这棵树完成一些操作:

$I. $$CHANGE$ \(u\) \(t\): 把结点\(u\)的权值改为\(t\)

\(II. QMAX\) \(u\) \(v\): 询问从点\(u\)到点\(v\)的路径上的节点的最大权值

\(III. QSUM\) \(u\) \(v\): 询问从点\(u\)到点\(v\)的路径上的节点的权值和

注意:从点\(u\)到点\(v\)的路径上的节点包括\(u\)\(v\)本身

输入输出格式

输入格式:

输入文件的第一行为一个整数\(n\),表示节点的个数。

接下来\(n – 1\)行,每行\(2\)个整数\(a\)\(b\),表示节点\(a\)和节点\(b\)之间有一条边相连。

接下来一行\(n\)个整数,第\(i\)个整数\(wi\)表示节点\(i\)的权值。

接下来\(1\)行,为一个整数\(q\),表示操作的总数。

接下来\(q\)行,每行一个操作,以“\(CHANGE\) \(u\) \(t\)”或者\(“QMAX\) \(u\) \(v”\)或者\(“QSUM\) \(u\) \(v”\)的形式给出。

输出格式:

对于每个\(“QMAX”\)或者\(“QSUM”\)的操作,每行输出一个整数表示要求输出的结果。

输入输出样例

输入样例

4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

输出样例

4
1
2
2
10
6
5
6
5
16

很显然的树链剖分

细节见代码

#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
long long read(){
    long long x=0;int f=0;char c=getchar();
    while(c<'0'||c>'9')f|=c=='-',c=getchar();
    while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
    return f?-x:x;
}
int n,q,w[100005];
struct Dier{//结构体存边
    int to,next;
}a[100005];
int head[100005],cnt;
void add(int x,int y){//前向星
    a[++cnt].to=y,a[cnt].next=head[x],head[x]=cnt;
}

//对树 抽筋剥骨
int son[100005],sz[100005],dep[100005],f[100005];
void dfs(int u,int fa){//第一遍dfs
    f[u]=fa,sz[u]=1,dep[u]=dep[fa]+1;
    int ma=0;
    for(int i=head[u],v;v=a[i].to,i;i=a[i].next){
        if(v==fa) continue;
        dfs(v,u);
        sz[u]+=sz[v];
        if(sz[v]>ma) son[u]=v,ma=sz[v];
    }
}
int top[100005],pt[100005],dfn[100005],num;
void dfs(int u,int fa,int tp){//第二遍dfs
    top[u]=tp,pt[u]=++num,dfn[num]=u;
    if(!son[u]) return;
    dfs(son[u],u,tp);
    for(int i=head[u],v;v=a[i].to,i;i=a[i].next){
        if(v==fa||v==son[u]) continue;
        dfs(v,u,v);
    }
}

//线段树
struct xtm{
    int maxx,sum;//结构体存线段树
    xtm(){maxx=-30000;}
}t[120005];//线段树要开四倍空间,被神仙嘲笑++
#define lc p<<1//左儿子
#define rc p<<1|1//右儿子
void pushup(int p){//更新上面的点
    t[p].sum=t[lc].sum+t[rc].sum;//由左儿子和右儿子得来
    t[p].maxx=max(t[lc].maxx,t[rc].maxx);
}
void build(int p,int l,int r){//初始化线段树
    if(l==r){
        //dfn[]记录的是当前序号对应的是几号节点,我们记录的左右区间及dfs序,因此直接用l找点即可
        //被神仙嘲笑++
        t[p].maxx=t[p].sum=w[dfn[l]];return;
    }
    int m=(l+r)>>1;
    build(lc,l,m),build(rc,m+1,r);//线段树常规更新
    pushup(p);
}
void updata(int p,int l,int r,int k,int x){//更改某点的值
    if(l==r){
        t[p].sum=t[p].maxx=x;return;
    }
    int m=(l+r)>>1;
    if(m>=k) updata(lc,l,m,k,x);
    else updata(rc,m+1,r,k,x);
    pushup(p);
}

//求最大值
int find_max(int p,int l,int r,int L,int R){
    if(l>R||r<L) return -30000;//细节,被神仙嘲笑++
    if(l>=L&&r<=R) return t[p].maxx;
    int m=(l+r)>>1;
    return max(find_max(lc,l,m,L,R),find_max(rc,m+1,r,L,R));
}
void get_max(){
    int ans=-30000,x=read(),y=read();
    while(top[x]!=top[y]){//只要两个点不在一条链上,我们就可以让链头深度低的往上跳,直到两点在一条链上
        //一开始用lca的我,被神仙嘲笑++
    	if(dep[top[x]]>dep[top[y]]) swap(x,y);
    	ans=max(ans,find_max(1,1,n,pt[top[y]],pt[y])),y=f[top[y]];//直接跳到链头父节点
    }
    if(dep[x]>dep[y]) swap(x,y);
    //这里x和y已经在一条链上了,因此要判断两点的深度,而非链头深度,被神仙嘲笑++
    ans=max(ans,find_max(1,1,n,pt[x],pt[y]));
    printf("%d\n",ans);
}

//求和
int find_sum(int p,int l,int r,int L,int R){
    if(l>R||r<L) return 0;
    if(l>=L&&r<=R) return t[p].sum;
    int m=(l+r)>>1;
    return find_sum(lc,l,m,L,R)+find_sum(rc,m+1,r,L,R);
}
void get_sum(){
    int ans=0,x=read(),y=read();
    while(top[x]!=top[y]){//同上
    	if(dep[top[x]]>dep[top[y]]) swap(x,y);
    	ans+=find_sum(1,1,n,pt[top[y]],pt[y]),y=f[top[y]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ans+=find_sum(1,1,n,pt[x],pt[y]);
    printf("%d\n",ans);
}

int main(){
    n=read();
    for(int i=1,x,y;i<n;++i){//读入,连边
        x=read(),y=read();
        add(x,y),add(y,x);
    }
    for(int i=1;i<=n;++i) w[i]=read();
    dfs(1,0),dfs(1,0,1),build(1,1,n);//预处理
    q=read();
    string s;
    while(q--){
        cin>>s;
        if(s[0]=='C'){int u=read(),t=read();updata(1,1,n,pt[u],t);}
        else if(s[1]=='M') get_max();
        else if(s[1]=='S') get_sum();
    }
    //cout<<被神仙嘲笑;GG
    return 0;
}

欢迎指正评论O(∩_∩)O~~

推荐阅读