平衡树

BST(Binary Search Tree):二叉搜索树

满足:

  1. 当前节点的左子树中的任何一个点的权值$<$当前节点的权值
  2. 当前节点的右子树中的任何一个点的权值$>$当前节点的权值

一般保证无重复权值,若有,可以在每个节点上记录当前权值的个数

可以发现,BST的中序遍历是有序的,因此BST的作用就是动态维护有序集合

操作:

  1. 插入
  2. 删除
  3. 找前驱(中序遍历中的前一个位置)和后继(中序遍历中的后一个位置)
  4. 找最大和最小

平衡树就是特殊的二叉搜索树

普通平衡树(Treap)

Treap——Tree+heap

使用堆的性质优化二叉搜索树,使用左旋和右旋让二叉搜索树保持堆的性质使二叉树的层数尽量小,减少每次操作的复杂度。

结点的保存:

1
2
3
4
5
struct Node{
int l,r;//左右儿子
int key,val;//二叉搜索树中的值和堆中的值(随机)
int cnt,size;//次数,自身所在子树大小
}tr[N];

初始化:新建两个哨兵结点,初始值为$-\infin,+\infin$

核心操作:旋转:左旋(zag),右旋(zig)

操作:

  1. 新建节点

    1
    2
    3
    4
    5
    6
    int get_node(int key){
    tr[++idx].key=key;
    tr[idx].val=rand();
    tr[idx].cnt=tr[idx].size=1;
    return idx;
    }
  2. pushup(计算size)

    1
    2
    3
    void pushup(int p){
    tr[p].size=tr[tr[p].l].size+tr[tr[p].r].size+tr[p].cnt;
    }
  3. 左旋右旋

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    void zig(int &p){//一定要传引用,此时p代表指向根节点的指针(根会改变)
    int q=tr[p].l;
    tr[p].l=tr[q].r,tr[q].r=p,p=q;//右挂左,拧左,改变指向根节点的指针
    pushup(tr[p].r),pushup(p);
    }

    void zag(int &p){
    int q=tr[p].r;
    tr[p].r=tr[q].l,tr[q].l=p,p=q;
    pushup(tr[p].l),pushup(p);
    }
  4. 初始化(新建哨兵结点)

    1
    2
    3
    4
    5
    6
    void build(){
    get_node(-INF),get_node(INF);
    root=1,tr[1].r=2;
    pushup(root);
    if (tr[1].val<tr[2].val)zag(root);//可能一开始就不符合堆的性质,旋转
    }
  5. 插入:与二叉搜索树的插入相同,但是插入完后需要旋转保持堆的性质

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    void insert(int &p,int key){//更新结点时也要更新祖先节点的相应值,所以也要传引用
    if(!p)p=get_node(key);
    else if(tr[p].key==key)tr[p].cnt++;
    else if(tr[p].key>key){
    insert(tr[p].l,key);
    if(tr[tr[p].l].val>tr[p].val)zig(p);
    }
    else{
    insert(tr[p].r,key);
    if(tr[tr[p].r].val>tr[p].val)zag(p);
    }
    pushup(p);
    }
  6. 删除

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    void remove(int &p,int key){
    if(!p)return;
    if(tr[p].key==key){
    if(tr[p].cnt>1)tr[p].cnt--;
    else if(tr[p].l||tr[p].r){
    if(!tr[p].r||tr[tr[p].l].val>tr[tr[p].r].val){
    zig(p);
    remove(tr[p].r,key);
    }
    else{
    zag(p);
    remove(tr[p].l,key);
    }
    }
    else p=0;
    }
    else if(tr[p].key>key){
    remove(tr[p].l,key);
    }
    else{
    remove(tr[p].r,key);
    }

    pushup(p);
    }
  7. 通过数值找排名

    注意:因为有哨兵结点的存在,最终求的排名比实际多一位

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    int get_rank_by_key(int p,int key){
    if(!p)return 0;//找不到
    if(tr[p].key==key)return tr[tr[p].l].size+1;
    if(tr[p].key>key)return get_rank_by_key(tr[p].l,key);
    return tr[tr[p].l].size+tr[p].cnt+get_rank_by_key(tr[p].r,key);
    }

    int main(){
    ...
    get_rank_by_key(root,key)-1;
    ...
    }
  8. 通过排名找数值

    注意:因为有哨兵结点的存在,查找时应查的排名应加一

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    int get_key_by_rank(int p,int rank){
    if(!p)return INF;
    if(tr[tr[p].l].size>=rank)return get_key_by_rank(tr[p].l,rank);
    if(tr[tr[p].l].size+tr[p].cnt>=rank)return tr[p].key;
    return get_key_by_rank(tr[p].r,rank-tr[tr[p].l].size-tr[p].cnt);
    }

    int main(){
    ...
    get_key_by_rank(root,rank+1);
    ...
    }
  9. 找前驱(严格小于k的最大数)

    1
    2
    3
    4
    5
    int get_prev(int p,int key){
    if(!p)return -INF;
    if(tr[p].key>=key)return get_prev(tr[p].l,key);
    return max(tr[p].key,get_prev(tr[p].r,key));
    }
  10. 找后继(严格大于k的最小数)

    与找前驱类似

    1
    2
    3
    4
    5
    int get_next(int p,int key){
    if(!p)return INF;
    if(tr[p].key<=key)return get_next(tr[p].r,key);
    return min(tr[p].key,get_next(tr[p].l,key));
    }

总模板:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#include <bits/stdc++.h>
using namespace std;

const int N=100010,INF=0x3f3f3f3f;
int n;
struct Node
{
int l, r;
int key, val;
int cnt, size;
}tr[N];
int root,idx;

int get_node(int key){
tr[++idx].key=key;
tr[idx].val=rand();
tr[idx].cnt=tr[idx].size=1;
return idx;
}

void pushup(int p){
tr[p].size=tr[tr[p].l].size+tr[tr[p].r].size+tr[p].cnt;
}

void zig(int &p){
int q=tr[p].l;
tr[p].l=tr[q].r,tr[q].r=p,p=q;
pushup(tr[p].r),pushup(p);
}

void zag(int &p){
int q=tr[p].r;
tr[p].r=tr[q].l,tr[q].l=p,p=q;
pushup(tr[p].l),pushup(p);
}

void build(){
get_node(-INF),get_node(INF);
root=1,tr[1].r=2;
pushup(root);
if (tr[1].val<tr[2].val)zag(root);
}

void insert(int &p,int key){
if(!p)p=get_node(key);
else if(tr[p].key==key)tr[p].cnt++;
else if(tr[p].key>key){
insert(tr[p].l,key);
if(tr[tr[p].l].val>tr[p].val)zig(p);
}
else{
insert(tr[p].r,key);
if(tr[tr[p].r].val>tr[p].val)zag(p);
}
pushup(p);
}

void remove(int &p,int key){
if(!p)return;
if(tr[p].key==key){
if(tr[p].cnt>1)tr[p].cnt--;
else if(tr[p].l||tr[p].r){
if(!tr[p].r||tr[tr[p].l].val>tr[tr[p].r].val){
zig(p);
remove(tr[p].r,key);
}
else{
zag(p);
remove(tr[p].l,key);
}
}
else p=0;
}
else if(tr[p].key>key){
remove(tr[p].l,key);
}
else{
remove(tr[p].r,key);
}

pushup(p);
}

int get_rank_by_key(int p,int key){
if(!p)return 0;
if(tr[p].key==key)return tr[tr[p].l].size+1;
if(tr[p].key>key)return get_rank_by_key(tr[p].l,key);
return tr[tr[p].l].size+tr[p].cnt+get_rank_by_key(tr[p].r,key);
}

int get_key_by_rank(int p,int rank){
if(!p)return INF;
if(tr[tr[p].l].size>=rank)return get_key_by_rank(tr[p].l,rank);
if(tr[tr[p].l].size+tr[p].cnt>=rank)return tr[p].key;
return get_key_by_rank(tr[p].r,rank-tr[tr[p].l].size-tr[p].cnt);
}

int get_prev(int p,int key){
if(!p)return -INF;
if(tr[p].key>=key)return get_prev(tr[p].l,key);
return max(tr[p].key,get_prev(tr[p].r,key));
}

int get_next(int p,int key){
if(!p)return INF;
if(tr[p].key<=key)return get_next(tr[p].r,key);
return min(tr[p].key,get_next(tr[p].l,key));
}

int main()
{
build();
cin >> n;
while(n--)
{
int op,x;
cin >> op >> x;
if(op==1)insert(root,x);
else if(op==2)remove(root,x);
else if(op==3)cout << get_rank_by_key(root,x)-1 << endl;
else if(op==4)cout << get_key_by_rank(root,x+1) << endl;
else if(op==5)cout << get_prev(root,x) << endl;
else cout << get_next(root,x) << endl;
}
return 0;
}

文艺平衡树(Splay)

支持序列中区间翻转的平衡二叉树,保证中序遍历为当前序列的顺序

存储:

1
2
3
4
5
6
7
8
9
struct Node{
int l,r,p,v;//左右子节点,父节点,编号
int size,flag;//子树大小,懒标记

void init(int _v,int _p){
v=_v,p=_p;
size=1;
}
}tr[N];

核心:每操作一个结点(插入,查询),都将该结点旋转到树根

维护信息(size,懒标记flag(记录翻转)):

  1. pushup:维护信息(当前点的size等于左右儿子的size之和加1)

    旋转之后

  2. pushdown:下传懒标记(翻转左右子树后标记下传)

    递归之前

核心操作:

  1. 旋转rotate(x)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    void rotate(int x){
    int y=tr[x].p,z=tr[y].p;
    int k=tr[y].s[1]==x;//k=0表示x是y的左儿子;k=1表示x是y的右儿子
    tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
    //改变y和z,x和z的关系,z的y原先所在儿子变为x,x的父节点变为z
    tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
    //改变y和x的相反儿子的关系,y的x原先所在儿子变为x的相反儿子,x的原先相反儿子的父节点变为y
    tr[x].s[k^1]=y,tr[y].p=x;
    //改变x和y的关系,x的相反儿子变为y,y的父节点变为x
    pushup(y),pushup(x);
    }
  2. splay(x,k):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    void splay(int x,int k){
    while(tr[x].p!=k){
    int y=tr[x].p,z=tr[y].p;
    if(z!=k){
    if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
    else rotate(y);
    }
    rotate(x);
    }
    if(!k)root=x;
    }
  3. 插入

    1
    2
    3
    4
    5
    6
    7
    8
    void insert(int v){
    int u=root,p=0;
    while(u)p=u,u=tr[u].s[v>tr[u].v];
    u=++idx;
    if(p)tr[p].s[v>tr[p].v]=u;
    tr[u].init(v,p);
    splay(u,0);
    }
  4. 删除

  5. 通过排名找数值

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    int get_k(int k){
    int u=root;
    while(true){
    pushdown(u);
    if(tr[tr[u].s[0]].size>=k)u=tr[u].s[0];
    else if(tr[tr[u].s[0]].size+1==k)return u;
    else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
    }
    return -1;
    }
  6. 输出序列(中序遍历)

    1
    2
    3
    4
    5
    6
    void output(int u){
    pushdown(u);
    if(tr[u].s[0])output(tr[u].s[0]);
    if(tr[u].v>=1&&tr[u].v<=n)cout << tr[u].v << ' ';
    if(tr[u].s[1])output(tr[u].s[1]);
    }

其余操作参考Treap

总模板:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include<bits/stdc++.h>
using namespace std;

const int N=100010;

int n,m;
struct Node{
int s[2],p,v;//左右节点,父节点,编号
int size,flag;

void init(int _v,int _p){
v=_v,p=_p;
size=1;
}
}tr[N];
int root,idx;

void pushup(int x){
tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+1;
}

void pushdown(int x){
if(tr[x].flag){
swap(tr[x].s[0],tr[x].s[1]);
tr[tr[x].s[0]].flag^=1,tr[tr[x].s[1]].flag^=1;
tr[x].flag=0;
}
}

void rotate(int x){
int y=tr[x].p,z=tr[y].p;
int k=tr[y].s[1]==x;//k=0表示x是y的左儿子;k=1表示x是y的右儿子
tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y,tr[y].p=x;
pushup(y),pushup(x);
}

void splay(int x,int k){
while(tr[x].p!=k){
int y=tr[x].p,z=tr[y].p;
if(z!=k){
if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
else rotate(y);
}
rotate(x);
}
if(!k)root=x;
}

void insert(int v){
int u=root,p=0;
while(u)p=u,u=tr[u].s[v>tr[u].v];
u=++idx;
if(p)tr[p].s[v>tr[p].v]=u;
tr[u].init(v,p);
splay(u,0);
}

int get_k(int k){
int u=root;
while(true){
pushdown(u);
if(tr[tr[u].s[0]].size>=k)u=tr[u].s[0];
else if(tr[tr[u].s[0]].size+1==k)return u;
else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
}
return -1;
}

void output(int u){
pushdown(u);
if(tr[u].s[0])output(tr[u].s[0]);
if(tr[u].v>=1&&tr[u].v<=n)cout << tr[u].v << ' ';
if(tr[u].s[1])output(tr[u].s[1]);
}

int main(){
cin >> n >> m;
for(int i=0;i<=n+1;i++){//0和n+1是哨兵结点
insert(i);
}
while(m--){
int l,r;
cin >> l >> r;
l=get_k(l),r=get_k(r+2);
splay(l,0),splay(r,l);
tr[tr[r].s[0]].flag^=1;
}
output(root);
return 0;
}

FHQ-Treap

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include<bits/stdc++.h>
using namespace std;

const int N=1e5+5;

struct node{
int l,r,sz,val,rd;
}tr[N];
int cnt;

int root,x,y,z;

void pushup(int u){
tr[u].sz=tr[tr[u].l].sz+tr[tr[u].r].sz+1;
}

void split(int u,int k,int &x,int &y){
if(!u){
x=y=0;
return;
}
if(tr[u].val<=k){
x=u;
split(tr[u].r,k,tr[u].r,y);
}
else{
y=u;
split(tr[u].l,k,x,tr[u].l);
}
pushup(u);
}

int merge(int u,int v){
if(!u||!v)return u|v;
if(tr[u].rd<tr[v].rd){
tr[u].r=merge(tr[u].r,v);
pushup(u);
return u;
}
else{
tr[v].l=merge(u,tr[v].l);
pushup(v);
return v;
}
}

int new_node(int a){
tr[++cnt].val=a,tr[cnt].sz=1,tr[cnt].rd=rand();
return cnt;
}

void insert(int a){
split(root,a,x,y);
root=merge(merge(x,new_node(a)),y);
}

void delete_all(int a){
split(root,a,x,y);
split(x,a-1,x,y);
root=merge(x,y);
}

void delete_one(int a){
split(root,a,x,y);
split(x,a-1,x,z);
z=merge(tr[z].l,tr[z].r);
root=merge(merge(x,z),y);
}

int get_key_by_rank(int u,int k){
while(true){
if(k<=tr[tr[u].l].sz)u=tr[u].l;
else if(k==tr[tr[u].l].sz+1)return u;
else{
k=k-tr[tr[u].l].sz-1;
u=tr[u].r;
}
}
}

int get_rank_by_key(int a){
split(root,a-1,x,y);
int k=tr[x].sz+1;
root=merge(x,y);
return k;
}

int findpre(int a){
split(root,a-1,x,y);
int tmp=tr[get_key_by_rank(x,tr[x].sz)].val;
root=merge(x,y);
return tmp;
}

int findnext(int a){
split(root,a,x,y);
int tmp=tr[get_key_by_rank(y,1)].val;
root=merge(x,y);
return tmp;
}

int n;

int main(){
cin >> n;
for(int i=1;i<=n;i++){
int opt,a;
cin >> opt >> a;
if(opt==1){
insert(a);
}
if(opt==2){
delete_one(a);
}
if(opt==3){
cout << get_rank_by_key(a) << endl;
}
if(opt==4){
cout << tr[get_key_by_rank(root,a)].val << endl;
}
if(opt==5){
cout << findpre(a) << endl;
}
if(opt==6){
cout << findnext(a) << endl;
}
}
return 0;
}