|
- #include <iostream>
- using namespace std;
// A node of the balanced search tree.
//   ch[0] / ch[1] : left / right child (the code in this file links missing
//                   children to the global sentinel `null`)
//   v             : stored value
//   s             : number of nodes in the subtree rooted here
struct node {
    node* ch[2];
    int v;
    int s;

    // Zero every member so that even a default-initialized node (`node n;`)
    // holds well-defined values, matching what `new node()` produces.
    node() : v(0), s(0)
    {
        ch[0] = ch[1] = 0;
    }

    // Tells which way x goes relative to this node:
    //   -1 if x == v, 0 if x < v (go left), 1 if x > v (go right).
    int cmp(int x) const
    {
        if (x == v)
            return -1;
        return x < v ? 0 : 1;
    }
};
- node* null = new node();
- node* newnode(int x)
- {
- node* tmp = new node();
- tmp->ch[0] = null;
- tmp->ch[1] = null;
- tmp->v = x;
- tmp->s = 1;
- return tmp;
- }
- int rank(node* o, int x)
- {
- if (o == null)
- return -1;
- if (o->cmp(x) == -1)
- return 1;
- else if (o->cmp(x) == 0)
- return rank(o->ch[0], x);
- else
- return o->ch[0]->s + rank(o->ch[1], x);
- }
- node* find_kth(node* o, int k)
- {
- if (k == 1) {
- while (o->ch[0] != null)
- o = o->ch[0];
- return o;
- }
- else {
- if (o->ch[0]->s >= k)
- return find_kth(o->ch[0], k);
- else
- return find_kth(o->ch[1], k - o->ch[0]->s);
- }
- }
- node* precursor(node* o, int x)
- {
- return find_kth(o, rank(o, x) - 1);
- }
- node* successor(node* o, int x)
- {
- return find_kth(o, rank(o, x) + 1);
- }
- void rotate(node* &o, int d)
- {
- node* k = o->ch[d ^ 1];
- o->ch[d ^ 1] = k->ch[d];
- k->ch[d] = o;
- k->s = o->s;
- o->s = o->ch[0]->s + o->ch[1]->s;
- o = k;
- }
- void maintain(node* o, bool d)
- {
- if (!d) {
- if (o->ch[0]->ch[0]->s > o->ch[1]->s)
- rotate(o, 1);
- else if (o->ch[0]->ch[1]->s > o->ch[1]->s) {
- rotate(o->ch[0], 0);
- rotate(o, 1);
- }
- else
- return;
- }
- else {
- if (o->ch[1]->ch[1]->s > o->ch[0]->s)
- rotate(o, 0);
- else if (o->ch[1]->ch[0]->s > o->ch[0]->s) {
- rotate(o->ch[1], 1);
- rotate(o, 0);
- }
- else
- return;
- }
- cout << "#";
- maintain(o->ch[0], false);
- maintain(o->ch[1], true);
- maintain(o, false);
- maintain(o, true);
- }
- void insert(node* &o, int x)
- {
- if (o == null)
- o = newnode(x);
- else {
- if (o->cmp(x) == 0)
- insert(o->ch[0], x);
- else
- insert(o->ch[1], x);
- ++o->s;
- maintain(o, false);
- maintain(o, true);
- }
- }
- void erase(node* &o, int x)
- {
- if (o->cmp(x) == -1) {
- if (o->ch[0] == null && o->ch[1] == null)
- o = null;
- else if (o->ch[0] == null)
- o = o->ch[1];
- else if (o->ch[1] == null)
- o = o->ch[0];
- else {
- node* tmp = find_kth(o->ch[1], 1);
- o->v = tmp->v;
- erase(o->ch[1], tmp->v);
- --o->s;
- }
- }
- else {
- if (o->cmp(x) == 0)
- erase(o->ch[0], x);
- else {
- erase(o->ch[1], x);
- }
- --o->s;
- }
- }
- node* T = null;
- int n;
- int main()
- {
- cin >> n;
- null->ch[0] = null;
- null->ch[1] = null;
- for (int i = 0, p; i != n; ++i) {
- int x;
- cin >> p >> x;
- if (p == 1)
- insert(T, x);
- if (p == 2)
- erase(T, x);
- if (p == 3)
- cout << rank(T, x) << endl;
- if (p == 4)
- cout << find_kth(T, x)->v << endl;
- if (p == 5)
- cout << precursor(T, x)->v << endl;
- if (p == 6)
- cout << successor(T, x)->v << endl;
- }
- return 0;
- }
复制代码 |
|