算法:线段树

1. 介绍

给你一段长度为N的序列A[],求:

  • 单点修改,单点查询
  • 单点修改,区间查询
  • 区间修改,单点查询
  • 取键修改,区间查询

1.1 线段树的建立与查询

  • 线段树将一个区间划分为一些单元区间,每个单元区间对应线段树中的一个叶节点
  • 根节点对应区间为$ [1, N] $
  • 对于线段树中的每一个非叶子节点$ [a,b] $,左儿子表示的区间为$ [a, (a+b)/2] $,右儿子表示的区间为$ [(a+b)/2+1, b] $。最后的叶子节点数目为$ N $
  • $ (a+b)/2 $可以写作$ (a+b)>>1 $
  • 堆式存储:父亲节点的编号为$ i $,则左儿子的编号为$ i\times 2 $,右儿子的编号为$ i \times 2+1 $
  • 可分别写作$ i << 1 $,$ i<<1|1 $

image-20210316205140853

1.2 代码实现

建树:

1
2
3
4
5
6
7
8
9
10
11
12
// o为当前节点编号,l和r分别是当前节点的左右端点
// 原数据为a[],建的树存放在sum[]
void Build_Tree(int o, int l, int r) {
if (l == r) {
sum[o] = a[l];
return;
}
int mid = (l+r) >> 1;
Build_Tree(o << 1, l, mid);
Build_Tree(o << 1 | 1, mid + 1, r);
sum[o] = sum[o << 1] + sum[o << 1 | 1];
}

单点修改:单点修改复杂度为$ O(log n) $,单点查询类似

1
2
3
4
5
6
7
8
9
10
11
12
// a[x] = y
// 自顶向下修改
void update(int o, int l, int r) {
if (l == r) {
sum[o] = a[l];
return;
}
int mid = (l + r) >> 1;
if (x <= mid) update(o << 1, l, mid);
else update(o << 1 | 1, mid + 1, r);
sum[o] = sum[o << 1] + sum[o << 1 | 1];
}

区间查询:可以通过访问不超过$ 2*log N $个区间来获取任意区间的答案

1
2
3
4
5
6
7
8
9
10
11
// 查询区间[x, y]
int ans = 0;
void Query(int o, int l, int r) {
if (x <= l && y >= r) {
ans += sum[o];
return;
}
int mid = (l + r) >> 1;
if (x <= mid) Query(o << 1, l, mid);
if (mid < y) Query(o << 1 | 1, mid + 1, r);
}

1.3 Lazy-Tag

应用于区间修改

  • Lazy-Tag记录的是每一个线段树节点的变化值
  • 当这部分区间的一致性被破坏时,就将这个变化值传递给子区间
  • 每个节点存一个Tag值,表示该区间进行的变化
  • 每访问到一个节点时,Tag下传
  • 需要使用满二叉树

image-20210317003909247

image-20210317003923809

蓝色是要把信息及时维护的节点,红色是本次区间修改操作Lazy-Tag下传停止的位置

多个Lazy Tag怎么办?

  • 考虑打标记运算的优先级

1.4 代码实现

  • 每个节点都要记录统计量和Lazy Tag
  • 相关函数传递三个参数(int o, int l, int r)
  • 数组大小:将$ N $补成2的幂次,再×2,或者直接开4倍空间
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
typedef long long ll;
const int DAT_SIZE = 1 << 17;

// 线段树,data[]存储lazy-tag,datb[]存储线段树
ll data[2 * DAT_SIZE - 1], datb[2 * DAT_SIZE - 1];

// 初始化
void init(int n_) {
// 将n_扩大到2的幂,再×2
n = 1;
while (n < n_) n *= 2;

for (int i = 0; i < 2 * n - 1; i++) data[i] = 0; datb[i] = 0;
}

// 对区间[a, b]同时加x
// o是当前节点编号,l和r分别是区间左右端点
void add(int a, int b, int x, int o, int l, int r) {
if (a <= l && r <= b) {
data[o] += x;
return;
}
int mid = (l + r) >> 1;
datb[o] += (min(b, r) - max(a, l) + 1) * x;
if (a <= mid) add(a, b, x, o << 1, l, mid);
if (mid < b) add(a, b, x, o << 1 | 1, mid + 1, r);
}

// 计算[a, b]的和
ll sum(int a, int b, int o, int l, int r) {
int ans = 0;
if (a > r || b < l) return ans;
if (a <= l && b >= r) {
ans += datb[o] + data[o] * (r - l + 1);
return ans;
}
else {
ans += (min(b, r) - max(a, l) + 1) * data[o];
int mid = (l + r) >> 1;
if (a <= mid) ans += sum(a, b, o << 1, l, mid);
if (b > mid) ans += sum(a, b, o << 1 | 1, mid + 1, r);
return ans;
}
}

2. 应用

CCF认证202012-5

两种操作:区间加值,区间乘值。然后查询。

旋转的操作没考虑。

拿了40分,预期应该拿60分。

Reference:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// 线段树(加法+乘法)
// 40分(本应该过60分)
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const ll maxn = 1e7 + 5;
const ll mod = 1e9 + 7;

class segmentTree {
private:
vector<ll> data, add, mul, rot;
public:
void init(int n)
{
data.resize(n << 2);
add.resize(n << 2);
mul.resize(n << 2);
rot.resize(n << 2);
}
void pushup(int o)
{
data[o] = (data[o << 1] + data[o << 1 | 1]) % mod;
}

void pushdown(int o, int l, int r)
{
int mid = (l + r) >> 1;
data[o << 1] = (data[o << 1] * mul[o] + ll(mid - l + 1) * add[o]) % mod;
data[o << 1 | 1] = (data[o << 1 | 1] * mul[o] + ll(r - mid) * add[o]) % mod;
mul[o << 1] = (mul[o << 1] * mul[o]) % mod;
mul[o << 1 | 1] = (mul[o << 1 | 1] * mul[o]) % mod;
add[o << 1] = (add[o << 1] * mul[o] + add[o]) % mod;
add[o << 1 | 1] = (add[o << 1 | 1] * mul[o] + add[o]) % mod;
mul[o] = 1;
add[o] = 0;
}

void build(int o, int l, int r)
{
add[o] = 0;
mul[o] = 1;
rot[o] = 0;
if (l == r) {
data[o] = 0;
return;
}
int mid = (l + r) >> 1;
build(o << 1, l, mid);
build(o << 1 | 1, mid + 1, r);
pushup(o);
}

// 对区间[m, n]同时加k
void add_update(int o, int l, int r, int m, int n, ll k)
{
if (m <= l && r <= n) {
add[o] = (add[o] + k) % mod;
data[o] = (data[o] + k * ll(r - l + 1)) % mod;
return;
}
pushdown(o, l, r);
int mid = (l + r) >> 1;
if (m <= mid) add_update(o << 1, l , mid, m, n, k);
if (mid < n) add_update(o << 1 | 1, mid + 1, r, m , n, k);
pushup(o);
}

// 对区间[m, n]同时乘以k
void mul_update(int o, int l, int r, int m, int n, ll k)
{
if (m <= l && r <= n) {
mul[o] = (mul[o] * k) % mod;
data[o] = (data[o] * k) % mod;
add[o] = (add[o] * k) % mod;
return;
}
pushdown(o, l, r);
int mid = (l + r) >> 1;
if (m <= mid) mul_update(o << 1, l, mid, m, n, k);
if (n > mid) mul_update(o << 1 | 1, mid + 1, r, m, n, k);
pushup(o);
}

// 对区间[m, n]旋转
void rot_update(int o, int l, int r, int m, int n)
{

}

// 对区间[m, n]查询
ll query(int o, int l, int r, int m, int n)
{
if (m <= l && r <= n) return data[o];
ll ans = 0;
int mid = (l + r) >> 1;
pushdown(o, l, r);
if (m <= mid) ans = (ans + query(o << 1, l, mid, m, n)) % mod;
if (mid < n) ans = (ans + query(o << 1 | 1, mid + 1, r, m, n)) % mod;
return ans;
}
};


int main()
{
int n, m;
scanf("%d%d", &n, &m);
segmentTree x = segmentTree();
segmentTree y = segmentTree();
segmentTree z = segmentTree();
x.init(n);
y.init(n);
z.init(n);
x.build(1, 1, n);
y.build(1, 1, n);
z.build(1, 1, n);
for (int i = 0; i < m; i++) {
ll num, l, r, a, b, c, k;
scanf("%lld%lld%lld", &num, &l, &r);
if (num == 1) {
scanf("%lld%lld%lld", &a, &b, &c);
x.add_update(1, 1, n, l, r, a);
y.add_update(1, 1, n, l, r, b);
z.add_update(1, 1, n, l, r, c);
}
else if (num == 2) {
scanf("%lld", &k);
x.mul_update(1, 1, n, l, r, k);
y.mul_update(1, 1, n, l, r, k);
z.mul_update(1, 1, n, l, r, k);
}
else if (num == 3) {
// x.rot_update(1, 1, n, l, r);
continue;
}
else if (num == 4) {
ll sumx = 0, sumy = 0, sumz = 0;
sumx = x.query(1, 1, n, l, r);
sumy = y.query(1, 1, n, l, r);
sumz = z.query(1, 1, n, l, r);
printf("%lld\n", (sumx * sumx % mod + sumy * sumy % mod + sumz * sumz % mod) % mod);
}
}

return 0;
}

Reference