线段树

概念

线段树是一种用于解决区间更新和区间查询问题的高效数据结构

  • 这里我把区间更新分成两种(upd函数)
    • 区间添加更新,给这个区间的子数组都添加上某个值
    • 区间覆盖更新,把这个区间的子数组都赋值为某个值
  • 而区间查询也大概分为两种(query函数)
    • 区间和查询,求这个区间的数组的元素值之和
    • 区间最大值查询,求这个区间的数组的元素值的最大值

数组无lazy线段树

模板:区间添加更新,区间和查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class SegmentTree {
int[] nums;
int len;
int[] tree;
public SegmentTree(int[] nums) {
this.nums = nums;
this.len = nums.length;
tree = new int[4 * len];
build(1, 0, len - 1);
}
void upd(int idx, int l, int r, int tl, int tr, int val) {
if (tl <= l && r <= tr) {
tree[idx] = val;
if (l != r) { // 没有用lazy数组的情况,就需要给所有子区间更新
int m = l + (r - l) / 2;
upd(2 * idx, l, m, tl, tr, val);
upd(2 * idx + 1, m + 1, r, tl, tr, val);
}
} else {
int m = l + (r - l) / 2;
if (tl <= m) upd(2 * idx, l, m, tl, tr, val);
if (m + 1 <= tr) upd(2 * idx + 1, m + 1, r, tl, tr, val);
tree[idx] = tree[2 * idx] + tree[2 * idx + 1];
}
}
int query(int idx, int l, int r, int tl, int tr) {
if (tl <= l && r <= tr) {
return tree[idx];
} else {
int m = l + (r - l) / 2, res = 0;
if (tl <= m) res += query(2 * idx, l, m, tl, tr);
if (m + 1 <= tr) res += query(2 * idx + 1, m + 1, r, tl, tr);
return res;
}
}
void build(int idx, int l, int r) {
if (l == r) {
tree[idx] = nums[l];
} else {
int m = l + (r - l) / 2;
build(2 * idx, l, m);
build(2 * idx + 1, m + 1, r);
tree[idx] = tree[2 * idx] + tree[2 * idx + 1];
}
}
}

数组lazy线段树

模板:区间添加更新,区间和查询(未验证)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class SegmentTree {
int[] tree;
int[] lazy;
int[] nums;
int n;
public SegmentTree(int[] nums) {
this.nums = nums;
this.n = nums.length;
this.tree = new int[4 * n];
this.lazy = new int[4 * n];
build(1, 0, n - 1);
}
void pushDown(int idx, int l, int r) {
if (lazy[idx] != 0) {
int m = l + (r - l) / 2;
tree[2 * idx] += lazy[idx] * (m - l + 1);
tree[2 * idx + 1] += lazy[idx] * (r - m);
lazy[2 * idx] += lazy[idx];
lazy[2 * idx + 1] += lazy[idx];
lazy[idx] = 0;
}
}
void pushUp(int idx) {
tree[idx] = tree[2 * idx] + tree[2 * idx + 1];
}
void update(int idx, int l, int r, int tl, int tr, int val) {
if (tl <= l && r <= tr) {
tree[idx] += val * (r - l + 1);
lazy[idx] += val;
} else {
pushDown(idx, l, r);
int m = l + (r - l) / 2;
if (tl <= m) update(2 * idx, l, m, tl, tr, val);
if (m + 1 <= tr) update(2 * idx + 1, m + 1, r, tl, tr, val);
pushUp(idx);
}
}
int query(int idx, int l, int r, int tl, int tr) {
if (tl <= l && r <= tr) {
return tree[idx];
} else {
pushDown(idx, l, r);
int m = l + (r - l) / 2;
int res = 0;
if (tl <= m) res += query(2 * idx, l, m, tl, tr);
if (m + 1 <= tr) res += query(2 * idx + 1, m + 1, r, tl, tr);
return res;
}
}
void build(int idx, int l, int r) {
if (l == r) {
tree[idx] = nums[l];
} else {
int m = l + (r - l) / 2;
build(2 * idx, l, m);
build(2 * idx + 1, m + 1, r);
pushUp(idx);
}
}
}

指针lazy线段树

对于有些题目值域过大,我们无法直接使用空间大小固定为 4×n 的常规线段树,就必须采用指针(动态开点)的方式了

模板:区间添加更新,区间和查询(未经题目验证)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class SegmentTree {
class Node {
Node l, r;
int val, lazy;
}
Node root = new Node();
void pushDown(Node cur, int l, int r) {
if (cur.l == null) cur.l = new Node();
if (cur.r == null) cur.r = new Node();
int m = l + (r - l) / 2;
cur.l.val += cur.lazy * (m - l + 1);
cur.r.val += cur.lazy * (r - m);
cur.l.lazy += cur.lazy;
cur.r.lazy += cur.lazy;
cur.lazy = 0;
}
void pushUp(Node cur) {
cur.val = cur.l.val + cur.r.val;
}
void upd(Node cur, int l, int r, int tl, int tr, int val) {
if (tl <= l && r <= tr) {
cur.val += val * (r - l + 1);
cur.lazy += val;
} else {
pushDown(cur, l, r);
int mid = l + (r - l) / 2;
if (tl <= mid) upd(cur.l, l, mid, tl, tr, val);
if (mid < tr) upd(cur.r, mid + 1, r, tl, tr, val);
pushUp(cur);
}
}
int query(Node cur, int l, int r, int tl, int tr) {
if (tl <= l && r <= tr) {
return cur.val;
} else {
pushDown(cur, l, r);
int mid = l + (r - l) / 2;
int res = 0;
if (tl <= mid) res += query(cur.l, l, mid, tl, tr);
if (mid < tr) res += query(cur.r, mid + 1, r, tl, tr);
pushUp(cur);
return res;
}
}
void build(Node cur, int l, int r, int[] nums) {
if (l == r) {
cur.val = nums[l];
} else {
int mid = l + (r - l) / 2;
pushDown(cur, l, r);
build(cur.l, l, mid, nums);
build(cur.r, mid + 1, r, nums);
pushUp(cur);
}
}
}