首页 > 技术文章 > ZOJ-2112-Dynamic Rankings(线段树套splay树)

arbitrary 2013-10-09 18:56 原文

题意:

完成两个操作:

1.询问一个区间里第k小的数;

2.修改数列中一个数的值。

分析:

线段树套平衡树,线段树中的每个节点都有一棵平衡树,维护线段树所记录的这个区间的元素。
这样处理空间上是O(nlogn)的,因为线段树有logn层,每层的平衡树所记的节点总数都有n个。
修改很容易想到,把所有包含要修改点的区间的平衡树都修改了就行了

查询使用二分答案的方法

// File Name: 2112.cpp
// Author: Zlbing
// Created Time: 2013年10月07日 星期一 18时24分39秒
#include<iostream>
#include<string>
#include<algorithm>
#include<cstdlib>
#include<cstdio>
#include<set>
#include<map>
#include<vector>
#include<cstring>
#include<stack>
#include<cmath>
#include<queue>
using namespace std;
#define CL(x,v); memset(x,v,sizeof(x));
#define INF 0x3f3f3f3f
#define LL long long
#define REP(i,r,n) for(int i=r;i<=n;i++)
#define RREP(i,n,r) for(int i=n;i>=r;i--)
#define lson l,m,root<<1
#define rson m+1,r,root<<1|1
const int MAXN=5e4+100;
//线段树套平衡树,线段树中的每个结点都有一颗平衡树,维护线段树所记录
//的这个区间的元素
//rt数组是维护线段树的数组,表示线段树结点中的平衡树的根
//对于spaly树中的修改只要将rt修改成rt[root]并传递root参数
struct SplayTree {
    int sz[MAXN*20];
    int ch[MAXN*20][2];
    int pre[MAXN*20];
    int top;
    int rt[MAXN<<2];
    inline void up(int x){
        sz[x]  = cnt[x]  + sz[ ch[x][0] ] + sz[ ch[x][1] ];
    }
    inline void Rotate(int x,int f){
        int y=pre[x];
        ch[y][!f] = ch[x][f];
        pre[ ch[x][f] ] = y;
        pre[x] = pre[y];
        if(pre[x]) ch[ pre[y] ][ ch[pre[y]][1] == y ] =x;
        ch[x][f] = y;
        pre[y] = x;
        up(y);
    }
    inline void Splay(int x,int goal,int root){//将x旋转到goal的下面
        while(pre[x] != goal){
            if(pre[pre[x]] == goal) Rotate(x , ch[pre[x]][0] == x);
            else   {
                int y=pre[x],z=pre[y];
                int f = (ch[z][0]==y);
                if(ch[y][f] == x) Rotate(x,!f),Rotate(x,f);
                else Rotate(y,f),Rotate(x,f);
            }
        }
        up(x);
        if(goal==0) rt[root]=x;
    }
    inline void RTO(int k,int goal,int root){//将第k位数旋转到goal的下面
        int x=rt[root];
        while(sz[ ch[x][0] ] != k-1) {
            if(k < sz[ ch[x][0] ]+1) x=ch[x][0];
            else {
                k-=(sz[ ch[x][0] ]+1);
                x = ch[x][1];
            }
        }
        Splay(x,goal,root);
    }
    inline void vist(int x){
        if(x){
            printf("结点%2d : 左儿子  %2d   右儿子  %2d   %2d sz=%d\n",x,ch[x][0],ch[x][1],val[x],sz[x]);
            vist(ch[x][0]);
            vist(ch[x][1]);
        }
    }
    inline void Newnode(int &x,int c){
        x=++top;
        ch[x][0] = ch[x][1] = pre[x] = 0;
        sz[x]=1; cnt[x]=1;
        val[x] = c;
    }
    inline void init(){
        top=0;
    }
    inline void Insert(int &x,int key,int f,int root){
        if(!x) {
            Newnode(x,key);
            pre[x]=f;
            Splay(x,0,root);
            return ;
        }
        if(key==val[x]){
            cnt[x]++;
            sz[x]++;
            return ;
        }else if(key<val[x]) {
            Insert(ch[x][0],key,x,root);
        } else {
            Insert(ch[x][1],key,x,root);
        }
        up(x);
    }

    void Del(int root){  //删除根结点
         if(cnt[rt[root]]>1)
         {
            cnt[rt[root]]--;
         }
         else
         {
             int t=rt[root];
             if(ch[rt[root]][1]) {
                 rt[root]=ch[rt[root]][1];
                 RTO(1,0,root);
                 ch[rt[root]][0]=ch[t][0];
                 if(ch[rt[root]][0]) pre[ch[rt[root]][0]]=rt[root];
             }
             else rt[root]=ch[rt[root]][0];
             pre[rt[root]]=0;
         }
         up(rt[root]);
    }
    void findpre(int x,int key,int &ans){  //找key前趋
        if(!x)  return ;
        if(val[x] <= key){
            ans=x;
            findpre(ch[x][1],key,ans);
        } else
            findpre(ch[x][0],key,ans);
    }
    void findsucc(int x,int key,int &ans){  //找key后继
        if(!x) return ;
        if(val[x]>=key) {
            ans=x;
            findsucc(ch[x][0],key,ans);
        } else
            findsucc(ch[x][1],key,ans);
    }
    void findkey(int x,int key,int &ans)//找key
    {
        if(!x)return;
        if(val[x]==key)
            ans=x;
        else if(val[x]>key)
            findkey(ch[x][0],key,ans);
        else
            findkey(ch[x][1],key,ans);
    }
    //找第K大数
    inline int find_kth(int x,int k,int root){
        if(k<sz[ch[x][0]]+1) {
            return find_kth(ch[x][0],k,root);
        }else if(k > sz[ ch[x][0] ] + cnt[x] )
            return find_kth(ch[x][1],k-sz[ch[x][0]]-cnt[x],root);
        else{
            Splay(x,0,root);
            return val[x];
        }
    }
    int cnt[MAXN*20];
    int val[MAXN*20];
//---------------------------------------------
//建立线段树和线段树中的每个结点的平衡树
    void build(int l,int r,int root)
    {
        rt[root]=0;
        for(int i=l;i<=r;i++)
            Insert(rt[root],a[i],0,root);
        if(l>=r)return;
        int m=(l+r)>>1;
        build(lson);
        build(rson);
    }
    void update(int l,int r,int root,int i,int x)
    {
        int ans=0;
        findkey(rt[root],a[i],ans);
        Splay(ans,0,root);
        Del(root);
        Insert(rt[root],x,0,root);

        if(l>=r)return;
        int m=(l+r)>>1;
        if(i<=m)update(lson,i,x);
        else update(rson,i,x);
    }
    int cntLess(int x,int key)
    {
        int ret=0;
        while(x)
        {
            if(val[x]>key)
                x=ch[x][0];
            else
            {
                ret+=cnt[x]+sz[ch[x][0]];
                x=ch[x][1];
            }
        }
        return ret;
    }
    int getnumLess(int l,int r,int root,int L,int R,int x)
    {
        if(L<=l&&R>=r)
            return cntLess(rt[root],x);
        int m=(l+r)>>1;
        int ret=0;
        if(L<=m)ret+=getnumLess(lson,L,R,x);
        if(R>m)ret+=getnumLess(rson,L,R,x);
        return ret;
    }
    int search(int L,int R,int k)
    {
        int l=0,r=INF;
        int ans=0;
        while(l<=r)
        {
            int m=(l+r)>>1;
            int cnt=getnumLess(1,n,1,L,R,m);
            if(cnt>=k)
            {
                r=m-1;
                ans=m;
            }
            else l=m+1;
        }
        return ans;
    }
    void solve()
    {
        scanf("%d%d",&n,&m);
        REP(i,1,n)
            scanf("%d",&a[i]);
        build(1,n,1);
        REP(i,1,m)
        {
            int x,y,z;
            char str[4];
            scanf("%s",str);
            if(str[0]=='C')
            {
                scanf("%d%d",&x,&y);
                update(1,n,1,x,y);
                a[x]=y;
            }
            else
            {
                scanf("%d%d%d",&x,&y,&z);
                int k=search(x,y,z);
                printf("%d\n",k);
            }
        }
    }
    int a[MAXN];
    int n,m;

}spt;
int main()
{
    int cas;
    scanf("%d",&cas);
    while(cas--)
    {
        spt.init();
        spt.solve();
    }
    return 0;
}

 

推荐阅读