题意为给定n(n<=100000)个数字,然后给出m(m<=50000)个询问,每次询问在[i,j]区间内的第k大的数字是多少。 用SBT的方法做就是先把m个询问按照起点从小到大排序。对于每个询问区间[i,j],要调整SBT中的元素,使得树中只有[i,j]中的全部元素,多余元素删除,然后就可以求得树中第k大的数。代码就当模板了。哪里有错误希望大家指出,谢谢。 SBT理论知识详见陈启峰大牛的论文。 这里有论文的中文版。
POJ2761_SBT #include<iostream> #include<algorithm> using namespace std; struct SBT { int left,right,s,key,cnt; void init() { left=right=0; s=1; } }a[100010]; int tol=0; int root=0; void left_rotate(int &t)//左旋 { int k=a[t].right; a[t].right=a[k].left; a[k].left=t; a[k].s=a[t].s; a[t].s=a[a[t].left].s+a[a[t].right].s+1; t=k; return ; } void right_rotate(int &t)//右旋 { int k=a[t].left; a[t].left=a[k].right; a[k].right=t; a[k].s=a[t].s; a[t].s=a[a[t].left].s+a[a[t].right].s+1; t=k; return ; } void maintain(int &t,bool flag)//保持 { if(flag==0) { if(a[a[a[t].left].left].s>a[a[t].right].s) right_rotate(t); else if(a[a[a[t].left].right].s>a[a[t].right].s) { left_rotate(a[t].left); right_rotate(t); } else return ; } else { if(a[a[a[t].right].right].s>a[a[t].left].s) left_rotate(t); else if(a[a[a[t].right].left].s>a[a[t].left].s) { right_rotate(a[t].right); left_rotate(t); } else return ; } maintain(a[t].left,0); maintain(a[t].right,1); maintain(t,0); maintain(t,1); return ; } void insert(int &t,int v)//插入 { if(t==0) { t=++tol; a[t].init(); a[t].key=v; } else { a[t].s++; if(v<a[t].key) insert(a[t].left,v); else insert(a[t].right,v); maintain(t,v>=a[t].key); } return ; } int del(int &t,int v)//删除 { if(!t) return 0; a[t].s--; if(v==a[t].key||v<a[t].key&&!a[t].left||v>a[t].key&&!a[t].right) { if(a[t].left&&a[t].right) { int p=del(a[t].left,v+1); a[t].key=a[p].key; return p; } else { int p=t; t=a[t].left+a[t].right; return p; } } else return del(v<a[t].key?a[t].left:a[t].right,v); } int find(int t,int k)//寻找第k小的数 { if(k<=a[a[t].left].s) return find(a[t].left,k); else if(k>a[a[t].left].s+1) return find(a[t].right,k-a[a[t].left].s-1); return a[t].key; } int getmax(int t)//返回最大的节点 { while(a[t].right) t=a[t].right; return t; } int getmin(int t)//返回最小的节点 { while(a[t].left) t=a[t].left; return t; } int val[100010]; struct que { int s,e,th,idx,ans; }q[50010]; bool cmp(que a,que b) { if(a.s==b.s) return a.e<b.e; return a.s<b.s; } bool cmp2(que a,que b) { return a.idx<b.idx; } int n,m; int main() { tol=0; root=0; int i,j,k; scanf("%d%d",&n,&m); for(i=1;i<=n;i++) scanf("%d",&val[i]); for(i=0;i<m;i++) { scanf("%d%d%d",&q[i].s,&q[i].e,&q[i].th); q[i].idx=i; } sort(q,q+m,cmp); for(i=q[0].s;i<=q[0].e;i++) insert(root,val[i]); q[0].ans=find(root,q[0].th); for(i=1;i<m;i++) { if(q[i].s>=q[i-1].e) { for(j=q[i-1].s;j<=q[i-1].e;j++) del(root,val[j]); for(j=q[i].s;j<=q[i].e;j++) insert(root,val[j]); } else if(q[i].s<=q[i-1].e&&q[i].e>q[i-1].e) { for(j=q[i-1].s;j<q[i].s;j++) del(root,val[j]); for(j=q[i-1].e+1;j<=q[i].e;j++) insert(root,val[j]); } else { for(j=q[i-1].s;j<q[i].s;j++) del(root,val[j]); for(j=q[i].e+1;j<=q[i-1].e;j++) del(root,val[j]); } q[i].ans=find(root,q[i].th); } sort(q,q+m,cmp2); for(i=0;i<m;i++) printf("%d\n",q[i].ans); return 0; }
|