首页 > 技术文章 > 浅谈树链剖分

Neworld1111 2018-11-17 15:21 原文

传送门

捡起了之前树剖的坑,之前只写了树剖LCA的板子但是树剖模板却一直没有写好,学长sjie非常贴心地给我们上课,于是我去补了一下板子,并且写下自己对树剖的理解。


树剖用来干什么?

  • 维护树上的修改
    • 求u到v的路径上的所有点的权值和
    • 修改u到v的路径上的所有点的权值
    • 求以u为根的子树的结点的权值和
    • 修改以u为根的子树的结点的权值
    • LCA

树剖怎么实现?

  • 利用dfs序,将树上节点编上号,然后我们利用编号得到一个数组,数组存的就是它们的值,对于上述修改查询,利用数据结构维护即可。

  • 如图,求出dfs序(我们直接假设编号就是dfs序)(8号是7号,9号是8号,一不小心就画错图了……)
    • 修改4到6路径上的所有点,进行剖分:
      • 修改1到4
      • 修改5到6
    • 查询同理
    • 修改2的子树
      • 修改2到4,其实就是2到\(2+size[2]-1\)\(size\)是子树大小,定义请看下文
    • 查询同理
  • 但你会发现这样子效率不高啊……修改点到点的路径上的点,有可能会被剖分成好几段,那么单次修改的代价为\(O(N\log N)\),如下图

  • 对于这幅图,我想修改2到9
    • 修改1到2
    • 修改3
    • 修改5
    • 修改7
    • 修改9
  • 编号基本没有连续的,单次修改的收益很低
  • 所以我们需要一种优秀的dfs序,使得剖分时代价更低
  • 于是乎……

一些简单的定义

  • 子树大小:以u作为根的子树的结点数
  • 重儿子:对于一个结点u,其所有儿子中子树大小最大的那个儿子,如果子树大小相等,则可以任取一个
  • 轻儿子:不是重儿子的儿子
  • 重边:与重儿子的连边
  • 轻边:与轻儿子的连边
  • 重链:重边组成的链

graph

如图,2是1的重儿子,4是2的重儿子,9是4的重儿子,他们形成了一条重链,链头为1

7是3的重儿子,他们形成了一条重链,链头为3

从此我们可以得出一些结论:

  • 推论1:每个轻儿子都有以自己为链头的重链
  • 推论2:如果v是u的轻儿子,那么\(size[v] \leq size[u]/2\)
    • 如果不是这样的话,那么v就一定是重儿子
  • 推论3:根节点到任意节点的树链上轻边个数一定不大于\(\log_2N\)
    • 由推论2得,每次走上一条轻边,size一定小于原来的一半,相当于一直在二分二分二分,那么深度最多也只是\(\log_2N\)
    • 除了轻边就是重链了,所以重链的个数也不大于\(\log_2N\)
  • 推论4:一个节点到它的链头的编号是连续的,一个节点的子树的编号也是连续的

最后!我们的dfs序就是!!!重儿子优先

对于之前那张图,求出新的dfs序:

  • 这样我们修改2到9试试:
    • 修改1到9(这里是dfs序)
    • 修改2到5
  • 非常快捷!!!!根据上面的推论,剖分的次数一定不超过\(\log N\)
  • 所以结合数据结构(比如线段树),单次修改路径的时间复杂度为\(O(\log^2N)\)
  • 修改子树不需要剖分,所以是\(O(\log N)\)

剖分?

  • 那么要怎样才能把路径剖成一段一段呢

  • 如果两个点在同一条重链上,那么直接修改即可,因为他们的编号连续

    • 比如2到10,因为在同一条重链上,直接修改区间\([dfn[2],dfn[10]]\)
  • 那么不在同一条重链上怎么办?比如10到12

  • 那好像只能一层层往上跳,像LCA那样,把经过的点修改一下就行了

  • 显然每次能跳重链还跳什么重边啊,所以一个点如果它的链头没有超过它的LCA,那么就可以直接修改这条链,然后节点往上跳就行了
  • 那么……选哪个点往上跳……怎么不超过LCA啊,LCA好像跟深度有关系……
  • 所以我们每次选取所在重链链头深度较深的往上跳
    • 为什么?两个节点往上跳,必定相交于LCA,而LCA只会在一条重链上,显然要么在那两个点的重链上,要么就是其他点的重链上
    • 同时对于那两个子节点,一定不会是链头较深的那个。因为链头较深,显然不会超过LCA,如果超过,那另一个点的重链也经过了LCA,产生重复
    • 所以我们选取所在重链链头深度较深的节点开始跳,一定不会超过LCA
  • 最后跳到同一条重链上,直接修改即可
  • 查询同理
while(top[x]!=top[y]) {//不在同一条重链
    if(d[top[x]]<d[top[y]]) std::swap(x,y);//令top[x]的深度更深
    update(z,dfn[top[x]],dfn[x],1,1,N);//修改一下
    x = fa[top[x]];//跳到它的父节点去,也就是跳出这一条重链
}
if(d[x]>d[y]) std::swap(x,y);
update(z,dfn[x],dfn[y],1,1,N);

代码思路

  • 一遍dfs求出重儿子、深度、子树大小、父节点
  • 再一遍dfs求出dfs序和所在重链的链头
  • 通过线段树或其他数据结构维护信息

dfs1

void dfs1(int u,int father) {
    fa[u] = father;//记录父节点
    d[u] = d[father] + 1;//记录深度
    size[u] = 1;son[u] = 0;//子树大小先设为1(本身),设不存在重儿子
    for(int i=head[u];i;i=G[i].next) {//遍历所有出边
        int v = G[i].v;if(v==father) continue;//不能是父节点
        dfs1(v,u);//递归
        size[u] += size[v];//加上子树大小
        if(size[v]>size[son[u]]) son[u] = v;//判断是否是当前最重的儿子
    }
}

dfs2

void dfs2(int u,int tp) {
    dfn[u] = ++num;//记录dfs序
    top[u] = tp;//记录链头
    val[num] = w[u];//dfs序作为其在新数组上的位置
    if(son[u]!=0) dfs2(son[u],tp);//优先递归重儿子
    for(int i=head[u];i;i=G[i].next) {//访问出边
        int v = G[i].v;if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);//每个轻儿子都有以自己为链头的重链
    }
}

代码

在这里,我们定义:

\(son[u]\)为重儿子

\(size[u]\)为子树大小

\(fa[u]\)为父节点

\(d[u]\)为深度

\(dfn[u]\)为dfs序

\(top[u]\)为所在重链的链头

#include <cstdio>
#include <cstring>
#include <algorithm>
#define MAXN 100005

struct edge {
    int v,next;
}G[MAXN<<1];
int head[MAXN];
int seg[MAXN<<2],tag[MAXN<<2];

int son[MAXN],size[MAXN],fa[MAXN],d[MAXN];
int dfn[MAXN],top[MAXN];
int w[MAXN],val[MAXN];

int N,M,R,P,tot = 0,num = 0;

inline void add(int u,int v) {
    G[++tot].v = v;G[tot].next = head[u];head[u] = tot;
}

#define mid ((l+r)>>1)
#define lson (rt<<1)
#define rson (rt<<1|1)
#define sizel (((l+r)>>1)-l+1)
#define sizer (r-((l+r)>>1))

inline void pushup(int rt) {
    seg[rt] = (seg[lson] + seg[rson])%P; 
}

void Build(int rt,int l,int r) {
    tag[rt] = 0;
    if(l==r) {
        seg[rt] = val[l];
        return;
    }
    Build(lson,l,mid);
    Build(rson,mid+1,r);
    pushup(rt);
}

void pushdown(int rt,int l,int r) {
    tag[lson] = (tag[lson]+tag[rt])%P;
    tag[rson] = (tag[rson]+tag[rt])%P;
    seg[lson] = (seg[lson]+tag[rt]*sizel)%P;
    seg[rson] = (seg[rson]+tag[rt]*sizer)%P;
    tag[rt] = 0;
}

void update(int C,int L,int R,int rt,int l,int r) {
    if(L>r||R<l) return;
    if(L<=l&&R>=r) {
        seg[rt] = (seg[rt]+C*(r-l+1))%P;
        tag[rt] = (tag[rt]+C)%P;
        return;
    }
    pushdown(rt,l,r);
    if(L<=mid) update(C,L,R,lson,l,mid);
    if(R>mid) update(C,L,R,rson,mid+1,r);
    pushup(rt);
}
int query(int L,int R,int rt,int l,int r) {
    if(L>r||R<l) return 0;
    if(L<=l&&R>=r) return seg[rt];
    pushdown(rt,l,r);
    int ans = 0;
    if(L<=mid) ans = query(L,R,lson,l,mid);
    if(R>mid) ans += query(L,R,rson,mid+1,r);
    pushup(rt);
    return ans%P;
}

void dfs1(int u,int father) {
    fa[u] = father;d[u] = d[father] + 1;
    size[u] = 1;son[u] = 0;
    for(int i=head[u];i;i=G[i].next) {
        int v = G[i].v;if(v==father) continue;
        dfs1(v,u);
        size[u] += size[v];
        if(size[v]>size[son[u]]) son[u] = v;
    }
}

void dfs2(int u,int tp) {
    dfn[u] = ++num;
    top[u] = tp;
    val[num] = w[u];
    if(son[u]!=0) dfs2(son[u],tp);
    for(int i=head[u];i;i=G[i].next) {
        int v = G[i].v;if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}

void chain_update(int x,int y,int z) {
    while(top[x]!=top[y]) {
        if(d[top[x]]<d[top[y]]) std::swap(x,y);
        update(z,dfn[top[x]],dfn[x],1,1,N);
        x = fa[top[x]];
    }
    if(d[x]>d[y]) std::swap(x,y);
    update(z,dfn[x],dfn[y],1,1,N);
}

int chain_query(int x,int y) {
    int ans = 0;
    while(top[x]!=top[y]) {
        if(d[top[x]]<d[top[y]]) std::swap(x,y);
        ans += query(dfn[top[x]],dfn[x],1,1,N);
        ans%=P;
        x = fa[top[x]];
    }
    if(d[x]>d[y]) std::swap(x,y);
    ans += query(dfn[x],dfn[y],1,1,N);
    return ans%P;
}

int main() {

    scanf("%d%d%d%d",&N,&M,&R,&P);
    for(int i=1;i<=N;++i) {
        scanf("%d",&w[i]);
    }

    int u,v;
    std::memset(head,0,sizeof(head));
    for(int i=1;i<N;++i) {
        scanf("%d%d",&u,&v);
        add(u,v);add(v,u);
    }

    d[R] = 0;
    size[0] = 0;
    dfs1(R,R);
    dfs2(R,R);
    Build(1,1,N);

    int opt,x,y,z;
    for(int i=1;i<=M;++i) {
        scanf("%d",&opt);
        if(opt==1) {
            scanf("%d%d%d",&x,&y,&z);
            chain_update(x,y,z);
        }
        else if(opt==2) {
            scanf("%d%d",&x,&y);
            printf("%d\n",chain_query(x,y));
        }
        else if(opt==3) {
            scanf("%d%d",&x,&z);
            update(z,dfn[x],dfn[x]+size[x]-1,1,1,N);
        }
        else {
            scanf("%d",&x);
            printf("%d\n",query(dfn[x],dfn[x]+size[x]-1,1,1,N));
        }
    }

    return 0;
}

推荐阅读