首页 > 技术文章 > 学习笔记——Splay

ZCETHAN 2021-07-30 21:16 原文

前言

前几天有幸听学长讲平衡树,想着好久没写博客了,记录一下。

简介

Splay,平衡树的一种,依靠每次将访问到的点旋到根来保持树的平衡。

并且,Splay 还可以高效解决序列翻转等操作。

实现

前提

以下代码是基于这样的定义的:

struct Tree{int ch[2],val,siz,fa;}nd[MAXN];//表示某一个节点
void pushup(int rt){nd[rt].siz=nd[nd[rt].ch[0]].siz+nd[nd[rt].ch[1]].siz+1;}//更新某个节点子树的大小
int chk(int rt){return rt==nd[nd[rt].fa].ch[1];}//返回 rt 是左孩子还是右孩子
void clear(int rt){
    nd[rt].ch[0]=nd[rt].ch[1]=0;
    nd[rt].cnt=nd[rt].siz=nd[rt].val=nd[rt].fa=0;
}//删除节点

旋转

首先,对于依赖旋转的平衡树,这个操作是十分重要的。我们可以通过图片来理解,如何将父子关系互换并且不违背二叉搜索树的性质。

旋转有两种,对于左儿子用右旋,右儿子用左旋,下面以右旋为例。


首先最关心的是二叉搜索树的性质,我们可以发现,原来有:

\[uls<u<urs<fa<fars \]

在旋转后仍然满足:

\[uls'<u'<urs'<fa'<fars' \]

因此旋转没有破坏二叉搜索树的美妙性质。那我们来看旋转的实现:

  1. \(u\) 变成 \(fa\) 的父亲的儿子,\(u\) 的父亲更改为 \(fa\) 的父亲;
  2. \(u\) 的右儿子变成 \(fa\) 的左儿子;
  3. 互换 \(u\)\(fa\) 的父子关系。

左旋也一样,左右儿子倒过来就可以,因此我们可以根据 \(u\) 初始是左儿子还是右儿子,将两种旋转合并。

void rot(int rt){
	int p=nd[rt].fa,g=nd[p].fa,d=chk(rt);
	nd[g].ch[chk(p)]=rt;nd[rt].fa=g;
	nd[p].ch[d]=nd[rt].ch[d^1];
	nd[nd[rt].ch[d^1]].fa=p;
	nd[rt].ch[d^1]=p;nd[p].fa=rt;
	pushup(p);pushup(rt);//注意最后的更新
}

由于旋转是平衡树的基本操作,所以这里就先这样,主要理解双旋对 Splay 的优化。

双旋

双旋关心节点 \(u\) 的父亲的儿子类型与 \(u\) 的关系。

对此,我们可以分成 \(3\) 种情况:

  1. 父亲是根节点。此时直接旋转就可以了。
  2. 如果父亲的儿子类型与 \(u\) 相同。那么就先旋父亲,再旋 \(u\)
  3. 如果不同。那么就把 \(u\) 旋转 \(2\) 次。

为什么要这么麻烦呢?前面说过,Splay 是把访问的节点旋到根来维护平衡的,那我直接一个一个旋不就好了?为什么要定义一个双旋呢?

很简单,来看一个例子:

如果我访问顺序是 \(5\to 4\to 3\cdots 1\to 5\to 4\cdots\),那么可以发现,如果只是用单旋,每次查找为 \(\mathcal{O}(n)\),而用单旋不能改变链的事实,所以总的复杂度会高达 \(\mathcal{O}(n^2)\),直接 GG。(不理解可以自己手动模拟一下,发现每次旋到根后,整体还是同样形状的链)

那如果采用双旋呢?

可以看一下第一次操作如果用双旋结果是什么(把 \(5\) 转到根上)。


可以看到改变了链的形式,使得高度《大大》降低。

所以采用双旋,可以有效规避在链的情况下出现时间爆炸的情况。

所以我们采用双旋来实现 Splay,而对于把一个节点旋到根的操作,我们称之 \(Splay\) 操作(((

void splay(int rt){
    while(nd[rt].fa!=gl){
        int p=nd[rt].fa,g=nd[p].fa;
        if(g==0) rot(rt);//case 1
        else if(chk(rt)==chk(p)) rot(p),rot(rt);//case 2
        else rot(rt),rot(rt);// case 3
    }root=rt;
}

别的操作

有了 \(Splay\) 操作,就基本完成了 Splay,接下来就是一些比较细节的,和别的平衡树异曲同工的操作了。只要记住,对于所有操作,我们只要对目标点操作完后 \(Splay\) 一下就可以了,非常舒服。

插入

void ins(int rt,int val,int f){
    if(!rt){
        rt=++tot;nd[rt].val=val;nd[rt].siz=1;nd[rt].cnt=1;
        nd[rt].fa=f;nd[f].ch[val>=nd[f].val]=rt;
        splay(rt,0);return;//这个 splay 表示把 rt 转到 0 的儿子,也就是根
        //后文中有实现方法
    }if(nd[rt].val==val){
    	nd[rt].cnt++;splay(rt,0);
    	return;
	}
    int d=(nd[rt].val<=val);
    ins(nd[rt].ch[d],val,rt);
}

查找第 k 大

int kth(int rt,int k){
    while(rt){
    	int lsiz=nd[nd[rt].ch[0]].siz;
    	if(k>=lsiz+1&&k<=lsiz+nd[rt].cnt){
    		splay(rt,0);return nd[rt].val;
		}else if(k<=lsiz) rt=nd[rt].ch[0];
		else k-=lsiz+nd[rt].cnt,rt=nd[rt].ch[1];
	}return nd[rt].val;
}

查找元素排名

注意元素可以不在序列中

int rank(int rt,int val){
	int ret=1;
	while(rt){
		if(nd[rt].val==val){
			ret+=nd[nd[rt].ch[0]].siz;
			splay(rt,0);return ret;
		}else if(val<nd[rt].val) rt=nd[rt].ch[0];
		else ret+=nd[rt].cnt+nd[nd[rt].ch[0]].siz,rt=nd[rt].ch[1];
	}return ret;
}

删除节点

void del(int rt,int x){
    if(nd[rt].val==x){splay(rt,0);
        if(nd[rt].cnt>1) nd[rt].cnt--;
        else if(nd[rt].ch[0]){
            int ls=nd[rt].ch[0];
            while(nd[ls].ch[1]) ls=nd[ls].ch[1];
            splay(ls,rt);nd[nd[rt].ch[1]].fa=ls;
            nd[ls].ch[1]=nd[rt].ch[1];root=ls;
			pushup(ls);clear(rt);nd[ls].fa=0;
        }else root=nd[rt].ch[1],nd[root].fa=0,clear(rt);
		return;
    }del(nd[rt].ch[nd[rt].val<x],x);
}

查找前驱/后继

注意元素可以不在序列中

int pre(int val){
	int rk=rank(root,val)-1;
	return kth(root,rk);
}
int suf(int val){
	int rk=rank(root,val+1);
	return kth(root,rk);
}

夹带私货(调试用)

可以输出中序遍历,虽然都会写……

void debug(int rt){
	if(nd[rt].ch[0]) debug(nd[rt].ch[0]);
	printf("%d %d ",nd[rt].val,nd[rt].cnt);
	if(nd[rt].ch[1]) debug(nd[rt].ch[1]);
}

用这个模板可以过洛谷上的数据加强版。

To be continued

序列上的 Splay

先来看个题哈:P3391 【模板】文艺平衡树

这题需要支持区间翻转,并输出最终结果。那如果暴力的话是 \(\mathcal{O}(n^2)\) 的,显然 TLE。

这时候,我们考虑把位置作为权值,建立一棵 Splay。如果对 \([l,r]\) 翻转,那么我们就在书上查找 \(l-1\)\(r+1\),然后把 \(l-1\) 转到根,把 \(r+1\) 转到根的右儿子,此时可以发现,根据二叉搜索树的性质,区间 \([l,r]\) 就是以根的右儿子的左儿子为根的子树。

然后我们在这里记一个 \(tag\),之后如果向下访问了,就 \(pushdown\) 即可。

呃,一点小问题,我们知道 \(Splay\) 操作是可以把节点旋转到根的,那怎么旋转到根的右儿子呢?我们可以多加一个 \(gl\) 参数,表示目标点的父亲。

void splay(int rt,int gl){
    while(nd[rt].fa!=gl){
        int p=nd[rt].fa,g=nd[p].fa;
        if(g==gl) rot(rt);
        else if(chk(rt)==chk(p)) rot(p),rot(rt);
        else rot(rt),rot(rt);
    }if(!gl) root=rt;
}

那我们就把这题做完了……

注意,由于我们用到 \(l-1\)\(r+1\),所以需要一个极小值和极大值防止翻转 \([1,n]\) 的时候爆炸~

贴个代码,以示诚意。

Code

#include<bits/stdc++.h>
#define ll long long
#define inf (1<<30)
#define INF (1ll<<60)
using namespace std;
const int MAXN=1e5+10;
int tot,root;
struct Tree{int ch[2],val,siz,fa,rev;}nd[MAXN];
void pushup(int rt){nd[rt].siz=nd[nd[rt].ch[0]].siz+nd[nd[rt].ch[1]].siz+1;}
int chk(int rt){return rt==nd[nd[rt].fa].ch[1];}
void pushdown(int rt){
    if(nd[rt].rev==0) return;
    nd[nd[rt].ch[0]].rev^=1;
    nd[nd[rt].ch[1]].rev^=1;
    nd[rt].rev=0;
    swap(nd[rt].ch[0],nd[rt].ch[1]);
}
void rot(int rt){
	int p=nd[rt].fa,g=nd[p].fa,d=chk(rt);
	nd[g].ch[chk(p)]=rt;nd[rt].fa=g;
	nd[p].ch[d]=nd[rt].ch[d^1];
	nd[nd[rt].ch[d^1]].fa=p;
	nd[rt].ch[d^1]=p;nd[p].fa=rt;
	pushup(p);pushup(rt);
}
void splay(int rt,int gl){
    while(nd[rt].fa!=gl){
        int p=nd[rt].fa,g=nd[p].fa;
        if(g==gl) rot(rt);
        else if(chk(rt)==chk(p)) rot(p),rot(rt);
        else rot(rt),rot(rt);
    }
    if(!gl) root=rt;
}
void ins(int rt,int val,int f){
    if(!rt){
        rt=++tot;nd[rt].val=val;nd[rt].siz=1;
        nd[rt].fa=f;nd[f].ch[val>=nd[f].val]=rt;
        splay(rt,0);return;
    }
    int d=(nd[rt].val<=val);
    ins(nd[rt].ch[d],val,rt);
}
int find(int rt,int k){
	pushdown(rt);
	int cur=nd[nd[rt].ch[0]].siz+1;
	if(k<cur) return find(nd[rt].ch[0],k);
	else if(k==cur) return rt;
	else return find(nd[rt].ch[1],k-cur);
}
int n,m;
void print(int rt){
	pushdown(rt);
    if(nd[rt].ch[0]) print(nd[rt].ch[0]);
    if(nd[rt].val-1>=1&&nd[rt].val-1<=n) printf("%d ",nd[rt].val-1);
    if(nd[rt].ch[1]) print(nd[rt].ch[1]);
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n+2;i++) ins(root,i,0);
    //这里和讲的不太一样,由于查找 0 比较麻烦,所以还是将整个数组向后移动一个,把 1 和 n+2 当成边界
    int l,r;
    while(m--){
        scanf("%d%d",&l,&r);
        l=find(root,l);r=find(root,r+2);
        splay(l,0); splay(r,l);
        nd[nd[nd[root].ch[1]].ch[0]].rev^=1;
    }
	print(root);
}

推荐阅读