主席树入门
--by Hzu_Tested
主席树,又叫函数式线段树或有历史版本的线段树。可以把主席树看做线段树的一个拓展,其本身是一个线段树集合,通过找到共用节点,使维护这个线段树堆变得容易而且节省内存。
抛砖引玉
给出n个数,如何找到这n个数里的第k大的数。
首先,我们对这n个数进行处理,令buf[i]表示这n个数里i出现的次数。
以4,2,1,3,2为例,构建buf[]={0,1,2,1,1},即0出现0次(这里我们不用到0),1出现1次,2出现2次,3出现1次,4出现1次。
可以构造线段树:
每个节点表示对应范围内的值出现的次数。
假设k=3,求[1,5]第3大的数。
每次判断当前节点的右节点的值是否大于等于k,若大于等于k,则说明第k大的数在右节点(的子树)中,进入右节点;若不大于k,则说明第k大的数在左节点(的子树)中,且第k大变为第k-rson.val大,即k要减去右节点的值(前rson.val大的数都在右边了)。反复上述过程,直到到达叶节点,其对应的范围即时所求。
那么如果求[l,r]的第k大呢,…对[l,r]建立一个线段树,再求解。
但是这么做最坏要建立n2棵线段树,时间空间上都不能接受。
这里引入一个线段树差的概念,即两棵结构完全一样的线段树a和b,将b上所有节点和a上对应节点做差值计算,得到的线段树即b-a。
(类似前缀和的差,sum[i]表示a[1]到a[i]的累加和,那么sum[b]-sum[a]等于a[a-1,b]的累加和)。
那么我们可以建立线段树集tree[i]表示[1,i]范围的线段树,且保证所有线段树的结构相同。这样只需建立n棵线段树就可以表示所有[l,r]的线段树了,tree[l,r]=tree[r]-tree[l-1]。
举例,求[3,5]k=3,那么取tree[5]-tree[3-1]。
按照之前的方法求找到结果为1。
然后我们发现,每两棵线段树之间只有一条链路不同,即增加的数到根节点这条链路。就相当于线段树上这个值被更新了,其他节点是不变的。
那么除了被更新的链路,其他相同的节点就可以共用(这是主席树的中心思想)。我们以tree[0]到tree[1]为例。
绿色部分为新增的链路,橘色部分为共用的部分(这里为每个节点标记了序号)。
看到这里就知道什么是“有历史版本的线段树”了吧!其实就是对线段树的更新不直接在原节点上更新,而是新建一条链路,配合没有更新的节点,生产一棵新的线段树。
最终4,2,1,3,2的到的主席树如下。
例题:POJ – 2104
题意:
给出n个数字的序列a[1…n](ai的绝对值不超过1e9,且不重复),和m个查询[l,r,k],求a[l…r]中第k小的数字。
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 using namespace std; 5 6 const int maxn = 100000+100; 7 int tree[maxn]; // 记录每棵线段树的根节点 8 9 struct Num { 10 int index,val,newval; 11 }num[maxn]; 12 13 bool cmp1(const Num& left,const Num& right) { 14 return left.val<right.val; 15 } 16 17 bool cmp2(const Num& left,const Num& right) { 18 return left.index<right.index; 19 } 20 21 struct Node { 22 int val,lson,rson; 23 }node[maxn<<5]; 24 int cnt; 25 26 int n,m; 27 28 int build(int l,int r) { 29 int root = cnt++; 30 node[root].val = 0; 31 int mid = (l+r)>>1; 32 if(l<r) { 33 node[root].lson = build(l,mid); 34 node[root].rson = build(mid+1,r); 35 } 36 return root; 37 } 38 39 int updata(int root,int now) { 40 int next = cnt++; 41 int res = next; 42 node[next].val = node[root].val+1; 43 int l = 1, r = n; 44 int mid; 45 while(l<r) { 46 mid = (l+r)>>1; 47 if(now<=mid) { 48 node[next].lson = cnt++; 49 node[next].rson = node[root].rson; 50 next = node[next].lson; 51 root = node[root].lson; 52 r = mid; 53 } else { 54 node[next].rson = cnt++; 55 node[next].lson = node[root].lson; 56 next = node[next].rson; 57 root = node[root].rson; 58 l = mid+1; 59 } 60 node[next].val = node[root].val+1; 61 } 62 return res; 63 } 64 65 int query(int ql,int qr,int k) { 66 int l = 1, r = n; 67 int mid; 68 while(l<r) { 69 mid = (l+r)>>1; 70 if(node[node[qr].lson].val-node[node[ql].lson].val >= k) { 71 ql = node[ql].lson; 72 qr = node[qr].lson; 73 r = mid; 74 } else { 75 k -= node[node[qr].lson].val-node[node[ql].lson].val; 76 ql = node[ql].rson; 77 qr = node[qr].rson; 78 l = mid+1; 79 } 80 } 81 return r; 82 } 83 84 int main() { 85 while(scanf("%d",&n)!=EOF) { 86 scanf("%d",&m); 87 for(int i=1;i<=n;++i) { 88 scanf("%d",&num[i].val); 89 num[i].index = i; 90 } 91 sort(num+1,num+1+n,cmp1); 92 for(int i=1;i<=n;++i) { 93 num[i].newval = i; //将序列改为1-n的序列 94 } 95 sort(num+1,num+1+n,cmp2); 96 // build 97 memset(node,0,sizeof(node)); 98 memset(tree,0,sizeof(tree)); 99 cnt = 0; 100 tree[0] = build(1,n); 101 for(int i=1;i<=n;++i) { 102 tree[i] = updata(tree[i-1],num[i].newval); 103 } 104 sort(num+1,num+1+n,cmp1); 105 int l,r,k; 106 while(m--) { 107 scanf("%d%d%d",&l,&r,&k); 108 // query 109 int q = query(tree[l-1],tree[r],k); 110 printf("%d\n",num[q].val); 111 } 112 } 113 return 0; 114 }