首页 > 技术文章 > BZOJ - 3730 震波 (点分树/边分树+树状数组)

asdfsag 2020-03-18 12:45 原文

两种操作:

1.查询与树上结点x距离不超过k的结点权值之和

2.将结点x的权值修改为y

点分树模板题。

首先考虑一种比较暴力的做法:用树形dp的思想,将树转化成有根树,设f[u][k]为结点u子树下与其距离不超过k的点权和,则ans(x,k)=f[u][k]+f[fa[u]][k-1]-f[u][k-2]+f[fa[fa[u]]][k-2]-f[fa[u]][k-3]...,可如果树太高的话就GG了,要是能把树高变成logn级别的就好了。

建立点分树(点分治过程中形成的树),每个结点维护两个树状数组,一个维护以该结点为重心的各个距离上的权值和,一个维护其虚父亲(点分树上的父亲)在“该结点方向”上的各个距离上的权值和(用于容斥)。还要维护点分树上每个结点到其所有祖先结点在原树上的距离,修改和询问就从当前结点开始不断向上跳即可。(点分树的树高是logn级别的,可以暴力往上跳)

 总复杂度$O(nlogn+qlog^2n)$

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int N=1e5+10,inf=0x3f3f3f3f;
 5 struct E {int v,nxt;} e[N<<1];
 6 int a[N],hd[N],ne,n,Q,siz[N],mx[N],vis[N],RT,tot,cnt[N],fa[N],rnk[N],ds[N][20];
 7 int buf[N*40],*ptr=buf;
 8 struct BIT {
 9     int *c,n;
10     int lb(int x) {return x&-x;}
11     void add(int u,int x) {for(; u<=n; u+=lb(u))c[u]+=x;}
12     int get(int u) {int ret=0; for(u=min(u,n); u; u-=lb(u))ret+=c[u]; return ret;}
13     void build(int* a) {
14         for(int i=1; i<=n; ++i)c[i]=a[i];
15         for(int i=1; i<=n; ++i)if(i+lb(i)<=n)c[i+lb(i)]+=c[i];
16     }
17 } c[N][2];
18 BIT newBIT(int n) {BIT t= {ptr,n}; ptr+=n+1; return t;}
19 void link(int u,int v) {e[ne]= (E) {v,hd[u]},hd[u]=ne++;}
20 void getrt(int u,int f) {
21     siz[u]=1,mx[u]=0;
22     for(int i=hd[u]; ~i; i=e[i].nxt) {
23         int v=e[i].v;
24         if(vis[v]||v==f)continue;
25         getrt(v,u),siz[u]+=siz[v],mx[u]=max(mx[u],siz[v]);
26     }
27     mx[u]=max(mx[u],tot-siz[u]);
28     if(mx[u]<mx[RT])RT=u;
29 }
30 void getdis(int u,int f,int d,int rk) {
31     ds[u][rk]=d;
32     cnt[d+1]+=a[u];
33     for(int i=hd[u]; ~i; i=e[i].nxt) {
34         int v=e[i].v;
35         if(vis[v]||v==f)continue;
36         getdis(v,u,d+1,rk);
37     }
38 }
39 void cal(int u,int f) {
40     int m=1;
41     for(; cnt[m]; ++m);
42     --m;
43     c[u][f]=newBIT(m);
44     c[u][f].build(cnt);
45     for(int i=1; i<=m; ++i)cnt[i]=0;
46 }
47 void solve(int u,int rk) {
48     rnk[u]=rk;
49     getdis(u,0,0,rk),cal(u,0),vis[u]=1;
50     for(int i=hd[u]; ~i; i=e[i].nxt) {
51         int v=e[i].v;
52         if(vis[v])continue;
53         getdis(v,0,0,rk+1),RT=0,tot=siz[v];
54         getrt(v,0),cal(RT,1),fa[RT]=u,solve(RT,rk+1);
55     }
56 }
57 void upd(int u,int x) {
58     x-=a[u],a[u]+=x;
59     int dis=0;
60     c[u][0].add(dis+1,x);
61     for(int v=u; fa[v]; v=fa[v]) {
62         dis=ds[u][rnk[fa[v]]];
63         c[fa[v]][0].add(dis+1,x),c[v][1].add(dis,x);
64     }
65 }
66 int qry(int u,int d) {
67     int ret=0,dis=d;
68     ret+=c[u][0].get(dis+1);
69     for(int v=u; fa[v]; v=fa[v]) {
70         dis=d-ds[u][rnk[fa[v]]];
71         if(dis>=0)ret+=c[fa[v]][0].get(dis+1)-c[v][1].get(dis);
72     }
73     return ret;
74 }
75 int main() {
76     memset(hd,-1,sizeof hd),ne=0;
77     scanf("%d%d",&n,&Q);
78     for(int i=1; i<=n; ++i)scanf("%d",&a[i]);
79     for(int i=1; i<n; ++i) {
80         int u,v;
81         scanf("%d%d",&u,&v);
82         link(u,v);
83         link(v,u);
84     }
85     mx[0]=inf,RT=0,tot=n,getrt(1,0),fa[RT]=0,solve(RT,0);
86     for(int la=0; Q--;) {
87         int f,u,x;
88         scanf("%d%d%d",&f,&u,&x);
89         u^=la,x^=la;
90         if(f==0)printf("%d\n",la=qry(u,x));
91         else upd(u,x);
92     }
93     return 0;
94 }

 还有一种做法是边分树,需要将原树转化成二叉树,稍微麻烦点但省去了容斥的过程,感觉相对更容易理解。

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 typedef long long ll;
  4 const int N=4e5+10,inf=0x3f3f3f3f;
  5 struct E {int v,c,nxt;} e[N];
  6 int a[N],hd[N],ne,n,Q,siz[N],vis[N],cnt[N],fa[N],tot,rt,mx,edge[N],mxd,rnk[N],ds[N][20];
  7 int buf[N*40],*ptr=buf;
  8 struct BIT {
  9     int *c,n;
 10     int lb(int x) {return x&-x;}
 11     void add(int u,int x) {for(; u<=n; u+=lb(u))c[u]+=x;}
 12     int get(int u) {int ret=0; for(u=min(u,n); u; u-=lb(u))ret+=c[u]; return ret;}
 13     void build(int* a) {
 14         for(int i=1; i<=n; ++i)c[i]=a[i];
 15         for(int i=1; i<=n; ++i)if(i+lb(i)<=n)c[i+lb(i)]+=c[i];
 16     }
 17 } c[N];
 18 BIT newBIT(int n) {BIT t= {ptr,n}; ptr+=n+1; return t;}
 19 void link(int u,int v,int c) {e[ne]= (E) {v,c,hd[u]},hd[u]=ne++;}
 20 vector<E> g[N];
 21 void rebuild(int u,int f) {
 22     int w=u;
 23     for(int i=hd[u]; ~i; i=e[i].nxt) {
 24         int v=e[i].v;
 25         if(v==f)continue;
 26         rebuild(v,u);
 27         g[w].push_back({v,1,0});
 28         if(~e[i].nxt&&~e[e[i].nxt].nxt)g[w].push_back({++tot,0,0}),w=tot;
 29     }
 30 }
 31 void rebuild() {
 32     tot=n,rebuild(1,0);
 33     memset(hd,-1,sizeof hd),ne=0;
 34     for(int i=1; i<=tot; ++i) {
 35         for(int j=0; j<g[i].size(); ++j)
 36             link(i,g[i][j].v,g[i][j].c),link(g[i][j].v,i,g[i][j].c);
 37         g[i].clear();
 38     }
 39 }
 40 void getrt(int u,int f,int sz) {
 41     siz[u]=1;
 42     for(int i=hd[u]; ~i; i=e[i].nxt) {
 43         int v=e[i].v;
 44         if(vis[i]||v==f)continue;
 45         getrt(v,u,sz),siz[u]+=siz[v];
 46         int t=max(siz[v],sz-siz[v]);
 47         if(t<mx)mx=t,rt=i;
 48     }
 49 }
 50 void getdis(int u,int f,int d,int rk) {
 51     ds[u][rk]=d;
 52     if(u<=n)cnt[d+1]+=a[u],mxd=max(mxd,d+1);
 53     for(int i=hd[u]; ~i; i=e[i].nxt) {
 54         int v=e[i].v,c=e[i].c;
 55         if(vis[i]||v==f)continue;
 56         getdis(v,u,d+c,rk);
 57     }
 58 }
 59 void cal(int u) {
 60     c[u]=newBIT(mxd);
 61     c[u].build(cnt);
 62     for(int i=1; i<=mxd; ++i)cnt[i]=0;
 63 }
 64 void solve(int u,int f,int sz,int rk) {
 65     if(sz==1) {edge[u]=f; return;}
 66     mx=inf,getrt(u,0,sz);
 67     int t=rt;
 68     vis[t]=vis[t^1]=1;
 69     fa[t]=fa[t^1]=f;
 70     rnk[t]=rnk[t^1]=rk;
 71     mxd=0,getdis(e[t].v,0,0,rk),cal(t);
 72     mxd=0,getdis(e[t^1].v,0,0,rk),cal(t^1);
 73     int a=siz[e[t].v],b=sz-siz[e[t].v];
 74     solve(e[t].v,rt,a,rk+1),solve(e[t^1].v,t^1,b,rk+1);
 75 }
 76 void upd(int u,int x) {
 77     x-=a[u],a[u]+=x;
 78     for(int t=edge[u]; ~t; t=fa[t]) {
 79         int dis=ds[u][rnk[t]];
 80         c[t].add(dis+1,x);
 81     }
 82 }
 83 int qry(int u,int d) {
 84     int ret=a[u];
 85     for(int t=edge[u]; ~t; t=fa[t]) {
 86         int dis=d-(ds[u][rnk[t]]+e[t].c);
 87         if(dis>=0)ret+=c[t^1].get(dis+1);
 88     }
 89     return ret;
 90 }
 91 int main() {
 92     memset(hd,-1,sizeof hd),ne=0;
 93     scanf("%d%d",&n,&Q);
 94     for(int i=1; i<=n; ++i)scanf("%d",&a[i]);
 95     for(int i=1; i<n; ++i) {
 96         int u,v;
 97         scanf("%d%d",&u,&v);
 98         link(u,v,0);
 99         link(v,u,0);
100     }
101     rebuild();
102     solve(1,-1,tot,0);
103     for(int la=0; Q--;) {
104         int f,u,x;
105         scanf("%d%d%d",&f,&u,&x);
106         u^=la,x^=la;
107         if(f==0)printf("%d\n",la=qry(u,x));
108         else upd(u,x);
109     }
110     return 0;
111 }

 

推荐阅读