// Reference: the paper 《运用伸展树解决数列维护问题》 ("Using Splay Trees to Solve Sequence-Maintenance Problems").
/* Problem: maintain a sequence under insert, delete, maximum-subarray-sum
 * query, range-sum query, range reverse, and range assign ("make same").
 *
 * Notes:
 *  - After every modification, splay the modified node to the root; that is
 *    how the aggregates above it are refreshed.
 *  - update() must push_down first.
 *  - Every node visit must push_down.
 */
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Capacity of the node pool, the free-list and the input buffer.
const int MAXN = 2 * 500000;
// Magnitude used as "minus infinity" for the empty node's ml/mc/mr.
const int INF = 1000000000;
// Integer max/min helpers used by the splay-tree aggregate code.
inline int max(int a, int b)
{
    if (a > b)
        return a;
    return b;
}

inline int min(int a, int b)
{
    if (a < b)
        return a;
    return b;
}
// A splay-tree node. Its subtree stands for a contiguous piece of the
// sequence; the aggregate fields describe that piece.
struct Node
{
    int ml;      // best sum of a non-empty prefix of the subtree's sequence
    int mc;      // best sum of any non-empty contiguous run in the subtree
    int mr;      // best sum of a non-empty suffix of the subtree's sequence
    int size;    // number of nodes in the subtree
    int sum;     // sum of all values in the subtree
    int val;     // value stored at this node
    bool rev;    // lazy tag: reverse this subtree
    bool same;   // lazy tag: set every value in this subtree to val
    Node *ch[2]; // children: ch[0] = left, ch[1] = right
    Node *pre;   // parent
};
struct Splay { Node *root,*null; Node data[MAXN],*Store[MAXN]; int top,cnt;
Node *newNode(int val) { Node *p; if(top)p=Store[top--]; else p=&data[cnt++]; p->sum=p->val=val; p->ml=p->mc=p->mr=val; p->size=1; p->ch[0]=p->ch[1]=p->pre=null; p->rev=p->same=false; return p; } void init() { top=cnt=0; null = newNode(-INF); null->size=null->sum=0; root=newNode(-INF); root->ch[1]=newNode(-INF); root->ch[1]->pre=root; update(root); } void update(Node *p) { if(p==null)return; //需要先push_down push_down(p); push_down(p->ch[0]); push_down(p->ch[1]);
p->size=p->ch[0]->size+p->ch[1]->size+1; p->sum=p->ch[0]->sum+p->ch[1]->sum+p->val;
p->ml=max(p->ch[0]->ml,p->ch[0]->sum+p->val+max(0,p->ch[1]->ml)); p->mr=max(p->ch[1]->mr,p->ch[1]->sum+p->val+max(0,p->ch[0]->mr));
p->mc=max(p->ch[0]->mc,p->ch[1]->mc); p->mc=max(p->mc,max(p->ch[0]->mr+p->ch[1]->ml,0)+p->val); p->mc=max(p->mc,max(p->ch[0]->mr,p->ch[1]->ml)+p->val); } void push_down(Node *p) { if(p==null)return; if(p->rev) { p->rev=false; p->ch[0]->rev^=1; p->ch[1]->rev^=1; swap(p->ch[0],p->ch[1]); swap(p->ml,p->mr); } if(p->same) { p->same=false; p->ch[0]->same=p->ch[1]->same=true; p->ch[0]->val=p->ch[1]->val=p->val; p->sum=p->val*p->size; p->ml=p->mc=p->mr=p->val*p->size; if(p->val<0) p->ml=p->mc=p->mr=p->val; } } void rotate(Node *x,int c) { Node *y=x->pre; push_down(y); push_down(x); y->ch[!c]=x->ch[c]; if(x->ch[c]!=null)x->ch[c]->pre=y; x->pre=y->pre; if(y->pre!=null) { if(y->pre->ch[0]==y)y->pre->ch[0]=x; else y->pre->ch[1]=x; } x->ch[c]=y; y->pre=x; update(y); if(y==root)root=x; } void splay(Node *x,Node *f) { if(x==null)return; for(push_down(x);x->pre!=f;) { if(x->pre->pre==f) { if(x->pre->ch[0]==x)rotate(x,1); else rotate(x,0); } else { Node *y=x->pre,*z=y->pre; if(z->ch[0]==y) { if(y->ch[0]==x)rotate(y,1),rotate(x,1); else rotate(x,0),rotate(x,1); } else { if(y->ch[1]==x)rotate(y,0),rotate(x,0); else rotate(x,1),rotate(x,0); } } } update(x); } void select(int k,Node *f) { Node *t; for(t=root;;) { push_down(t); int size=t->ch[0]->size; if(k==size)break; if(k<size)t=t->ch[0]; else k-=size+1,t=t->ch[1]; } splay(t,f); } Node *build(int left,int right,int *ary) { if(left>right)return null; int mid=(left+right)>>1; Node *p=newNode(ary[mid]); p->ch[0]=build(left,mid-1,ary); if(p->ch[0]!=null)p->ch[0]->pre=p; p->ch[1]=build(mid+1,right,ary); if(p->ch[1]!=null)p->ch[1]->pre=p; update(p); return p; } void insert(int pos,int *ary,int n) { select(pos-1,null); select(pos,root); Node *p=build(0,n-1,ary); root->ch[1]->ch[0]=p; p->pre=root->ch[1]; splay(p,null); } void del(int start,int end)//[start,end) { select(start-1,null); select(end,root); Node *p=root->ch[1]->ch[0]; root->ch[1]->ch[0]=null; splay(root->ch[1],null); recyle(p);// } void recyle(Node *p) { if(p==null)return; 
recyle(p->ch[0]); recyle(p->ch[1]); Store[++top]=p; } int getSum(int start,int end)//[start,end) { select(start-1,null); select(end,root); return root->ch[1]->ch[0]->sum; } int maxSum(int start,int end) { select(start-1,null); select(end,root); return root->ch[1]->ch[0]->mc; } void makeSame(int start,int end,int c) { select(start-1,null); select(end,root); Node *p=root->ch[1]->ch[0]; p->same=true; p->val=c; splay(p,null); } void reverse(int start,int end) { select(start-1,null); select(end,root); Node *p=root->ch[1]->ch[0]; p->rev^=1; splay(p,null); } void print(Node *p) { if(p==null)return; push_down(p); print(p->ch[0]); printf("%d %d %d %d %d\n",p->val,p->ml,p->mc,p->mr,p->mc); print(p->ch[1]); } }splay;
// Input buffer: main() reads the initial N values, and later each INSERT's n
// values, into it before handing it to Splay::insert.
int ary[MAXN];
int main() { //freopen("in","r",stdin); int T; for(scanf("%d",&T);T--;) { int N,Q,n,x,y; char cmd[20]; scanf("%d%d",&N,&Q); for(int i=0;i<N;i++) scanf("%d",&ary[i]); splay.init(); splay.insert(1,ary,N); for(;Q--;) { scanf("%s",cmd); if(strcmp(cmd,"GET-SUM")==0) { scanf("%d%d",&x,&n); printf("%d\n",splay.getSum(x,x+n)); } if(strcmp(cmd,"MAX-SUM")==0) { printf("%d\n",splay.maxSum(1,splay.root->size-1)); } if(strcmp(cmd,"INSERT")==0) { scanf("%d%d",&x,&n); x++; for(int i=0;i<n;i++) scanf("%d",&ary[i]); splay.insert(x,ary,n); } if(strcmp(cmd,"DELETE")==0) { scanf("%d%d",&x,&n); splay.del(x,x+n); } if(strcmp(cmd,"REVERSE")==0) { scanf("%d%d",&x,&n); splay.reverse(x,x+n); } if(strcmp(cmd,"MAKE-SAME")==0) { scanf("%d%d%d",&x,&n,&y); splay.makeSame(x,x+n,y); } } } return 0; }
|