首页 > 技术文章 > 树链剖分 「笔记」

JHTBlog 2020-10-03 23:40 原文


考虑这个问题:

要求在树上进行两种操作

1.把 \(x\)\(y\) 的简单路径上所有点的权值加 \(k\)

2.询问 \(x\)\(y\) 的简单路径上的点权和

(任意单个问题都可以通过树上差分或树上前缀和维护,合起来不就是树上线段树?)

1.轻重链剖分


类似在数列中维护区间和,树链剖分把树剖分成多条链,可以用各种数据结构维护链。

模板题:轻重链剖分

做以下定义:

重儿子:子树大小最大的儿子

重边:连重儿子的边

重链:重边连起来的链

轻儿子、边、链:其它的儿子、边、链

c83d70cf3bc79f3d3adc2d8cb9a1cd11728b2949.jpg

用两个DFS预处理

DFS序标号节点,DFS时先走重儿子,保证每条重链的 \(dfn\) 连续。

然后用DFS序建立线段树。

\(tim\) 时间戳,\(fa\) 父亲节点编号,\(dep\) 深度,\(siz\) 子树大小,\(son\) 重儿子编号,

\(dfn\) DFS序,\(x\_u\) DFS序向节点编号的映射,\(top\) 链顶节点,轻儿子的链顶是它自己

int tim;
int fa[N], dep[N], siz[N], son[N], dfn[N], x_u[N], top[N];
void Dfs1(int x, int ff) { // 处理fa, dep, siz, son
	fa[x] = ff; dep[x] = dep[ff] + 1; siz[x]++;
	for(int i = head[x]; i; i = edge[i].nxt) {
		int to = edge[i].to;
		if(to != ff) {
			Dfs1(to, x);
			siz[x] += siz[to];
			if(siz[to] > siz[son[x]]) son[x] = to;
		}
	}
}
void Dfs2(int x, int ff) { // 处理dfn, x_u, top. ff为链顶节点
	dfn[x] = ++tim, x_u[dfn[x]] = x;
	top[x] = ff;
	if(!son[x]) return ;
	Dfs2(son[x], ff); // 先走重儿子
	for(int i = head[x]; i; i = edge[i].nxt) {
		int to = edge[i].to;
		if(!dfn[to]) Dfs2(to, to);
	}
}

线段树部分

显然子树内DFS序是个连续的区间,3、4操作可以简单做。

// 最普通的线段树
struct Tree {
	int l, r;
	long long sum, lazy;
} tree[N << 2];
void Update(int x) { 
	tree[x].sum = (tree[x << 1].sum + tree[x << 1 | 1].sum) % mod;
}
void Pushdown(int x) {
	long long k = tree[x].lazy;
	k %= mod;
	tree[x].lazy = 0;
	tree[x << 1].lazy += k, tree[x << 1 | 1].lazy += k;
	tree[x << 1].lazy %= mod, tree[x << 1 | 1].lazy %= mod;
	tree[x << 1].sum += k * (tree[x << 1].r - tree[x << 1].l + 1);
	tree[x << 1].sum %= mod;
	tree[x << 1 | 1].sum += k * (tree[x << 1 | 1].r - tree[x << 1 | 1].l + 1);
	tree[x << 1 | 1].sum %= mod;
}
void Build(int x, int l, int r) {
	tree[x].l = l, tree[x].r = r;
	if(l == r) {
		tree[x].sum = a[x_u[l]] % mod; // 用dfn建树
		return ;
	}
	int mid = (l + r) >> 1;
	Build(x << 1, l, mid), Build(x << 1 | 1, mid + 1, r);
	Update(x);
}
void Addsum(int x, int l, int r, long long k) {
	if(tree[x].l >= l && tree[x].r <= r) {
		tree[x].sum += k * (tree[x].r - tree[x].l + 1) % mod;
		tree[x].sum %= mod;
		tree[x].lazy = (tree[x].lazy + k) % mod;
		return ;
	}
	if(tree[x].lazy) Pushdown(x);
	int mid = (tree[x].l + tree[x].r) >> 1;
	if(l <= mid) Addsum(x << 1, l, r, k);
	if(r > mid) Addsum(x << 1 | 1, l, r, k);
	Update(x);
}
long long Reqsum(int x, int l, int r) {
	if(tree[x].l >= l && tree[x].r <= r) return tree[x].sum % mod;
	long long res = 0;
	if(tree[x].lazy) Pushdown(x);
	int mid = (tree[x].l + tree[x].r) >> 1;
	if(l <= mid) res += Reqsum(x << 1, l, r);
	if(r > mid) res += Reqsum(x << 1 | 1, l, r);
	return res % mod;
}

考虑1、2操作怎么做

\(x,y\) 沿着重链往上跳,直到它们跳到同一条重链上,最后考虑 \(x, y\) 之间的这一段。

具体就是当 \(top_x \neq top_y\) 时,将 \(dep_{top}\) 较大的跳到 \(fa_{top}\) 并处理经过的这条重链,因为每条重链的DFS序是连续的,所以可以在线段树上区间操作。

void Pathadd(int x, int y, long long k) {
	while(top[x] != top[y]) {
		if(dep[top[x]] < dep[top[y]]) swap(x, y); // 跳top深度大的
		Addsum(1, dfn[top[x]], dfn[x], k); // 将这条重链加上k
		x = fa[top[x]];
	}
	if(dfn[x] > dfn[y]) swap(x, y);
	Addsum(1, dfn[x], dfn[y], k);
}
long long Pathsum(int x, int y) {
	long long res = 0;
	while(top[x] != top[y]) {
		if(dep[top[x]] < dep[top[y]]) swap(x, y);
		res += Reqsum(1, dfn[top[x]], dfn[x]);
        	res %= mod;
        	x = fa[top[x]];
	}
    	if(dfn[x] > dfn[y]) swap(x, y);
    	res += Reqsum(1, dfn[x], dfn[y]);
    	return res % mod;
}

那么这题就可以做了……

点击展开代码
#include "iostream"
#include "cstdio"
#include "algorithm"
using namespace std;
const int N = 1e5 + 5, M = 1e5 + 5;
int n, m, root, mod;
long long a[N];
int head[N], cnt;
struct Edge{
	int to, nxt;
} edge[M << 1];
void Add(int from, int to) {
	edge[++cnt].to = to; edge[cnt].nxt = head[from];
	head[from] = cnt;
}
int tim;
int fa[N], dep[N], siz[N], son[N], dfn[N], x_u[N], top[N];
void Dfs1(int x, int ff) {
	fa[x] = ff; dep[x] = dep[ff] + 1; siz[x]++;
	for(int i = head[x]; i; i = edge[i].nxt) {
		int to = edge[i].to;
		if(to != ff) {
			Dfs1(to, x);
			siz[x] += siz[to];
			if(siz[to] > siz[son[x]]) son[x] = to;
		}
	}
}
void Dfs2(int x, int ff) {
	dfn[x] = ++tim, x_u[dfn[x]] = x;
	top[x] = ff;
	if(!son[x]) return ;
	Dfs2(son[x], ff);
	for(int i = head[x]; i; i = edge[i].nxt) {
		int to = edge[i].to;
		if(!dfn[to]) Dfs2(to, to);
	}
}
struct Tree {
	int l, r;
	long long sum, lazy;
} tree[N << 2];
void Update(int x) { 
	tree[x].sum = (tree[x << 1].sum + tree[x << 1 | 1].sum) % mod;
}
void Pushdown(int x) {
	long long k = tree[x].lazy;
	k %= mod;
	tree[x].lazy = 0;
	tree[x << 1].lazy += k, tree[x << 1 | 1].lazy += k;
	tree[x << 1].lazy %= mod, tree[x << 1 | 1].lazy %= mod;
	tree[x << 1].sum += k * (tree[x << 1].r - tree[x << 1].l + 1);
	tree[x << 1].sum %= mod;
	tree[x << 1 | 1].sum += k * (tree[x << 1 | 1].r - tree[x << 1 | 1].l + 1);
	tree[x << 1 | 1].sum %= mod;
}
void Build(int x, int l, int r) {
	tree[x].l = l, tree[x].r = r;
	if(l == r) {
		tree[x].sum = a[x_u[l]] % mod;
		return ;
	}
	int mid = (l + r) >> 1;
	Build(x << 1, l, mid), Build(x << 1 | 1, mid + 1, r);
	Update(x);
}
void Addsum(int x, int l, int r, long long k) {
	if(tree[x].l >= l && tree[x].r <= r) {
		tree[x].sum += k * (tree[x].r - tree[x].l + 1) % mod;
		tree[x].sum %= mod;
		tree[x].lazy = (tree[x].lazy + k) % mod;
		return ;
	}
	if(tree[x].lazy) Pushdown(x);
	int mid = (tree[x].l + tree[x].r) >> 1;
	if(l <= mid) Addsum(x << 1, l, r, k);
	if(r > mid) Addsum(x << 1 | 1, l, r, k);
	Update(x);
}
long long Reqsum(int x, int l, int r) {
	if(tree[x].l >= l && tree[x].r <= r) return tree[x].sum % mod;
	long long ans = 0;
	if(tree[x].lazy) Pushdown(x);
	int mid = (tree[x].l + tree[x].r) >> 1;
	if(l <= mid) ans += Reqsum(x << 1, l, r);
	if(r > mid) ans += Reqsum(x << 1 | 1, l, r);
	return ans % mod;
}
void Pathadd(int x, int y, long long k) {
	while(top[x] != top[y]) {
		if(dep[top[x]] < dep[top[y]]) swap(x, y);
		Addsum(1, dfn[top[x]], dfn[x], k);
		x = fa[top[x]];
	}
	if(dfn[x] > dfn[y]) swap(x, y);
	Addsum(1, dfn[x], dfn[y], k);
}
long long Pathsum(int x, int y) {
	long long res = 0;
	while(top[x] != top[y]) {
		if(dep[top[x]] < dep[top[y]]) swap(x, y);
		res += Reqsum(1, dfn[top[x]], dfn[x]);
        	res %= mod;
        	x = fa[top[x]];
	}
    	if(dfn[x] > dfn[y]) swap(x, y);
    	res += Reqsum(1, dfn[x], dfn[y]);
    	return res % mod;
}
int main() {
	scanf("%d%d%d%d", &n, &m, &root, &mod);
	for(int i = 1; i <= n; ++i) scanf("%lld", &a[i]);
	for(int i = 1; i < n; ++i) {
		int x, y;
		scanf("%d%d", &x, &y);
		Add(x, y), Add(y, x);
	}
	Dfs1(root, 0); Dfs2(root, root);
	Build(1, 1, n);
	while(m--) {
		int o; scanf("%d", &o) ;
		if(o == 1) {
			int x, y; long long z;
			scanf("%d%d%lld", &x, &y, &z);
			Pathadd(x, y, z);
		}
		if(o == 2) {
			int x, y;
			scanf("%d%d", &x, &y);
			printf("%lld\n", Pathsum(x, y));
		}
		if(o == 3) {
			int x; long long z;
			scanf("%d%lld", &x, &z);
			Addsum(1, dfn[x], dfn[x] + siz[x] - 1, z);
		}
		if(o == 4) {
			int x;
			scanf("%d", &x);
			printf("%lld\n", Reqsum(1, dfn[x], dfn[x] + siz[x] - 1));
		} 
	}
	return 0;
}

2.换根


Loj#139

推荐阅读