Orz zkw!!!
最近看完了《统计的力量》……觉得这实在是太神了……原来线段树可以这么写……
zkw线段树的思想:先把序列长度补齐为2的整数次幂(代码中记为n),使线段树成为一棵满二叉树。这样叶子结点就是T[n..2n-1],结点i的左右儿子为2i和2i+1,父亲为i>>1,于是可以通过位运算直接定位到任意结点,自底向上地完成修改和查询,不必递归,因此大大减小了常数……
本沙茶利用zkw线段树在BZOJ1756和1798上都刷到了rank3……
代码:
<1>BZOJ1756(单点修改 + 区间最大子段和查询):
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
using namespace std;
// Loop shorthands (upward: re*, downward: rre*):
//   re(i, n)      : i = 0 .. n-1        re1(i, n)     : i = 1 .. n
//   re2(i, l, r)  : i = l .. r-1        re3(i, l, r)  : i = l .. r
//   rre(i, n)     : i = n-1 .. 0        rre1(i, n)    : i = n .. 1
//   rre2(i, r, l) : i = r-1 .. l        rre3(i, r, l) : i = r .. l
// NOTE(review): macro arguments are not parenthesised; pass simple expressions only.
#define re(i, n) for (int i=0; i<n; i++)
#define re1(i, n) for (int i=1; i<=n; i++)
#define re2(i, l, r) for (int i=l; i<r; i++)
#define re3(i, l, r) for (int i=l; i<=r; i++)
#define rre(i, n) for (int i=n-1; i>=0; i--)
#define rre1(i, n) for (int i=n; i>0; i--)
#define rre2(i, r, l) for (int i=r-1; i>=l; i--)
#define rre3(i, r, l) for (int i=r; i>=l; i--)
#define ll long long
// MAXN bounds the padded (power-of-two) leaf count; -INF seeds the query maximum.
const int MAXN = (1 << 19) + 10, INF = ~0U >> 2;
// Segment-tree node for maximum-subarray-sum queries over the node's range:
//   sum = total sum, lv = best non-empty prefix sum,
//   rv  = best non-empty suffix sum, v = best non-empty subarray sum.
// Layout: T[1] is the root, children of i are 2i and 2i+1, leaves are T[n..2n-1].
struct node {
int sum, lv, rv, v;
} T[MAXN << 1];
// n: leaf count (smallest power of two >= input length), N = 2n,
// A: current array values, Z[2^k] = k (filled in prepare), res: answer of the last opr1.
int n, N, A[MAXN], Z[MAXN], res;
// Reads the next (optionally negative) decimal integer from stdin.
// Characters below '0' other than '-' are skipped; after the sign/first
// digit, every character >= '0' is accumulated as a digit.
// Returns 0 if EOF is reached before a sign or digit is seen, instead of
// spinning forever (the previous `char ch` could never match EOF).
inline int get_int()
{
    int ch;  // int, not char: getchar() returns EOF as an int
    while ((ch = getchar()) != EOF && ch < 48 && ch != '-') ;
    if (ch == EOF) return 0;
    bool neg = (ch == '-');
    int x = neg ? 0 : ch - 48;
    while ((ch = getchar()) >= 48) x = x * 10 + ch - 48;
    return neg ? -x : x;
}
void prepare()
{
N = n << 1;
re2(i, n, N) T[i].sum = T[i].lv = T[i].rv = T[i].v = A[i - n];
for (int i=0; (1<<i)<=n; i++) Z[1 << i] = i;
int lch, rch, _;
rre2(i, n, 1) {
lch = i << 1; rch = lch ^ 1;
T[i].sum = T[lch].sum + T[rch].sum;
T[i].lv = (_ = T[lch].sum + T[rch].lv) >= T[lch].lv ? _ : T[lch].lv;
T[i].rv = (_ = T[rch].sum + T[lch].rv) >= T[rch].rv ? _ : T[rch].rv;
T[i].v = T[lch].v >= T[rch].v ? T[lch].v : T[rch].v;
if ((_ = T[lch].rv + T[rch].lv) >= T[i].v) T[i].v = _;
}
}
void opr0(int pos, int x)
{
int i = pos + n; T[i].sum = T[i].lv = T[i].rv = T[i].v = x; int _, __ = x - A[pos], lch, rch; A[pos] = x;
for (i>>=1; i; i>>=1) {
lch = i << 1; rch = lch ^ 1;
T[i].sum += __;
T[i].lv = (_ = T[lch].sum + T[rch].lv) >= T[lch].lv ? _ : T[lch].lv;
T[i].rv = (_ = T[rch].sum + T[lch].rv) >= T[rch].rv ? _ : T[rch].rv;
T[i].v = T[lch].v >= T[rch].v ? T[lch].v : T[rch].v;
if ((_ = T[lch].rv + T[rch].lv) >= T[i].v) T[i].v = _;
}
}
// Maximum-subarray-sum query over A[l..r] (0-based, caller ensures l <= r);
// the answer is stored in the global res.
// The range is walked left to right over canonical tree nodes: first an
// ascending phase (take the largest aligned block starting at the cursor
// while it still fits), then a descending phase with halving block sizes
// until the cursor reaches the exclusive end. sum0 carries the best sum of
// a suffix ending at the cursor (starts at 0 before the first block).
void opr1(int l, int r)
{
// Map to leaf indices (l | n == l + n since l < n); r becomes exclusive.
int sum0 = 0, l0, i, _; l |= n; r |= n; r++; res = -INF;
// Ascending phase: block [l0, l0 + lowbit(l0)) is the subtree of node l0 / lowbit(l0).
for (; l0=l, (l+=l&-l)<=r; ) {
i = l0 / (l0 & -l0);
// best subarray entirely inside this block
if (T[i].v > res) res = T[i].v;
// best subarray that starts earlier and ends inside this block
if ((_ = sum0 + T[i].lv) > res) res = _;
// extend the running best suffix through this block
sum0 += T[i].sum; if (T[i].rv > sum0) sum0 = T[i].rv;
}
// Descending phase: the remainder r - l0 is smaller than lowbit(l0), so it is
// covered by blocks of size s = 2^z with s halving each step (Z is the log table).
int s = (l0 & -l0) >> 1, z = Z[s];
for (; l0<r; s>>=1, z--) if ((l = l0 + s) <= r) {
i = l0 >> z;
if (T[i].v > res) res = T[i].v;
if ((_ = sum0 + T[i].lv) > res) res = _;
sum0 += T[i].sum; if (T[i].rv > sum0) sum0 = T[i].rv;
l0 = l;
}
}
int main()
{
int n0, m, x, y, z;
n0 = get_int(); m = get_int(); re(i, n0) A[i] = get_int(); for (n=1; n<n0; n<<=1) ;
prepare();
re(i, m) {
x = get_int(); y = get_int(); z = get_int();
if (x == 1) {if (y > z) {x = y; y = z; z = x;} opr1(--y, --z); printf("%d\n", res);} else opr0(--y, z);
}
return 0;
}
<2>BZOJ1798(区间乘、区间加 + 区间求和取模):
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
using namespace std;
// Loop shorthands (upward: re*, downward: rre*):
//   re(i, n)      : i = 0 .. n-1        re1(i, n)     : i = 1 .. n
//   re2(i, l, r)  : i = l .. r-1        re3(i, l, r)  : i = l .. r
//   rre(i, n)     : i = n-1 .. 0        rre1(i, n)    : i = n .. 1
//   rre2(i, r, l) : i = r-1 .. l        rre3(i, r, l) : i = r .. l
// NOTE(review): macro arguments are not parenthesised; pass simple expressions only.
#define re(i, n) for (int i=0; i<n; i++)
#define re1(i, n) for (int i=1; i<=n; i++)
#define re2(i, l, r) for (int i=l; i<r; i++)
#define re3(i, l, r) for (int i=l; i<=r; i++)
#define rre(i, n) for (int i=n-1; i>=0; i--)
#define rre1(i, n) for (int i=n; i>0; i--)
#define rre2(i, r, l) for (int i=r-1; i>=l; i--)
#define rre3(i, r, l) for (int i=r; i>=l; i--)
#define ll long long
// MAXN bounds the padded (power-of-two) leaf count.
const int MAXN = (1 << 17) + 10;
// Segment-tree node with lazy tags:
//   mr0 = pending multiply tag (1 = none), mr1 = pending add tag (0 = none);
//         dm() applies them to the children as x -> x * mr0 + mr1 (mod MOD).
//   sum = range sum mod MOD, already including this node's own pending tags.
//   len = height of the node; it covers 2^len leaves (range-add uses c << len).
// Layout: T[1] is the root, children of i are 2i and 2i+1, leaves are T[n..2n-1].
struct node {
ll mr0, mr1, sum;
int len;
} T[MAXN << 1];
// n: leaf count (power of two >= input length), s: number of tree levels
// (log2(n) + 1, set in main), N = 2n, A: input values.
int n, s, N, A[MAXN];
// MOD: the modulus read from input; res: answer of the last opr2.
ll MOD, res;
// Reads the next non-negative decimal integer from stdin: skips every
// character below '0', then accumulates characters above '/' as digits.
// Returns 0 if EOF is reached before any digit, instead of spinning forever
// (the previous `char ch` skip loop had no way to observe EOF).
inline int get_int()
{
    int ch;  // int, not char: getchar() returns EOF as an int
    while ((ch = getchar()) != EOF && ch < 48) ;
    if (ch == EOF) return 0;
    int x = ch - 48;
    while ((ch = getchar()) > 47) x = x * 10 + ch - 48;
    return x;
}
void prepare()
{
N = n << 1; int lch, rch;
re2(i, n, N) {T[i].mr0 = 1; T[i].sum = A[i - n] % MOD; T[i].len = 0;}
rre2(i, n, 1) {
lch = i << 1; rch = lch ^ 1; T[i].len = T[lch].len + 1;
T[i].mr0 = 1; T[i].sum = T[lch].sum + T[rch].sum; if (T[i].sum >= MOD) T[i].sum -= MOD;
}
}
inline void dm(int i)
{
int lch = i << 1, rch = lch ^ 1; ll c0;
if ((c0 = T[i].mr0) ^ 1) {
T[i].mr0 = 1;
T[lch].mr0 = T[lch].mr0 * c0 % MOD; T[lch].mr1 = T[lch].mr1 * c0 % MOD; T[lch].sum = T[lch].sum * c0 % MOD;
T[rch].mr0 = T[rch].mr0 * c0 % MOD; T[rch].mr1 = T[rch].mr1 * c0 % MOD; T[rch].sum = T[rch].sum * c0 % MOD;
}
if (c0 = T[i].mr1) {
T[i].mr1 = 0;
T[lch].mr1 += c0; if (T[lch].mr1 >= MOD) T[lch].mr1 -= MOD; T[lch].sum = (T[lch].sum + (c0 << T[lch].len)) % MOD;
T[rch].mr1 += c0; if (T[rch].mr1 >= MOD) T[rch].mr1 -= MOD; T[rch].sum = (T[rch].sum + (c0 << T[rch].len)) % MOD;
}
}
void opr0(int l, int r, ll c)
{
int k, l0 = l | n, r0 = r | n, i, j, lch, rch; ll c0;
for (k=s-1; k&&(i=l0>>k)==r0>>k; k--) dm(i);
for (int k0=k; ((i=l0>>k0)<<k0)^l0; k0--) dm(i);
for (int k0=k; (((i=r0>>k0)<<k0)|((1<<k0)-1))^r0; k0--) dm(i);
r0++; int l1;
for (; l1=l0, (l0+=l0&-l0)<=r0; ) {
i = l1 / (l1 & -l1);
T[i].mr0 = T[i].mr0 * c % MOD; T[i].mr1 = T[i].mr1 * c % MOD; T[i].sum = T[i].sum * c % MOD;
}
int _ = (l1 & -l1) >> 1, z = __builtin_ctz(_);
for (; l1<r0; _>>=1, z--) if (l1 + _ <= r0) {
i = l1 >> z;
T[i].mr0 = T[i].mr0 * c % MOD; T[i].mr1 = T[i].mr1 * c % MOD; T[i].sum = T[i].sum * c % MOD;
l1 += _;
}
l0 = l | n; r0 = r | n;
for (k=0; (i=l0>>k)^(j=r0>>k); k++) {
if ((i << k) ^ l0) {T[i].sum = T[i << 1].sum + T[(i << 1) ^ 1].sum; if (T[i].sum >= MOD) T[i].sum -= MOD;}
if (((j << k) | ((1 << k) - 1)) ^ r0) {T[j].sum = T[j << 1].sum + T[(j << 1) ^ 1].sum; if (T[j].sum >= MOD) T[j].sum -= MOD;}
}
for (; k<s; k++) {
i = l0 >> k;
if ((i << k) ^ l0 || ((j << k) | ((1 << k) - 1)) ^ r0) {T[i].sum = T[i << 1].sum + T[(i << 1) ^ 1].sum; if (T[i].sum >= MOD) T[i].sum -= MOD;}
}
}
void opr1(int l, int r, ll c)
{
int k, l0 = l | n, r0 = r | n, i, j, lch, rch; ll c0;
for (k=s-1; k&&(i=l0>>k)==r0>>k; k--) dm(i);
for (int k0=k; ((i=l0>>k0)<<k0)^l0; k0--) dm(i);
for (int k0=k; (((i=r0>>k0)<<k0)|((1<<k0)-1))^r0; k0--) dm(i);
r0++; int l1;
for (; l1=l0, (l0+=l0&-l0)<=r0; ) {
i = l1 / (l1 & -l1);
T[i].mr1 += c; if (T[i].mr1 >= MOD) T[i].mr1 -= MOD; T[i].sum = (T[i].sum + (c << T[i].len)) % MOD;
}
int _ = (l1 & -l1) >> 1, z = __builtin_ctz(_);
for (; l1<r0; _>>=1, z--) if (l1 + _ <= r0) {
i = l1 >> z;
T[i].mr1 += c; if (T[i].mr1 >= MOD) T[i].mr1 -= MOD; T[i].sum = (T[i].sum + (c << T[i].len)) % MOD;
l1 += _;
}
l0 = l | n; r0 = r | n;
for (k=0; (i=l0>>k)^(j=r0>>k); k++) {
if ((i << k) ^ l0) {T[i].sum = T[i << 1].sum + T[(i << 1) ^ 1].sum; if (T[i].sum >= MOD) T[i].sum -= MOD;}
if (((j << k) | ((1 << k) - 1)) ^ r0) {T[j].sum = T[j << 1].sum + T[(j << 1) ^ 1].sum; if (T[j].sum >= MOD) T[j].sum -= MOD;}
}
for (; k<s; k++) {
i = l0 >> k;
if ((i << k) ^ l0 || ((j << k) | ((1 << k) - 1)) ^ r0) {T[i].sum = T[i << 1].sum + T[(i << 1) ^ 1].sum; if (T[i].sum >= MOD) T[i].sum -= MOD;}
}
}
void opr2(int l, int r)
{
res = 0;
int k, l0 = l | n, r0 = r | n, i, j, lch, rch; ll c0;
for (k=s-1; k&&(i=l0>>k)==r0>>k; k--) dm(i);
for (int k0=k; ((i=l0>>k0)<<k0)^l0; k0--) dm(i);
for (int k0=k; (((i=r0>>k0)<<k0)|((1<<k0)-1))^r0; k0--) dm(i);
r0++; int l1, r1;
for (; l1=l0, (l0+=l0&-l0)<=r0; ) {
i = l1 / (l1 & -l1); res += T[i].sum;
}
int _ = (l1 & -l1) >> 1, z = __builtin_ctz(_);
for (; l1<r0; _>>=1, z--) if (l1 + _ <= r0) {
i = l1 >> z; res += T[i].sum;
l1 += _;
}
res %= MOD;
}
int main()
{
int n0 = get_int(); MOD = get_int(); re(i, n0) A[i] = get_int(); for (n=1, s=0; n<n0; n<<=1, s++) ; s++;
prepare(); int M, _, l, r, c; M = get_int();
re(i, M) {
_ = get_int(); l = get_int(); r = get_int(); l--; r--;
if (_ == 1) {
c = get_int() % MOD; opr0(l, r, c);
} else if (_ == 2) {
c = get_int() % MOD; opr1(l, r, c);
} else {
opr2(l, r); printf("%d\n", (int) res);
}
}
return 0;
}