【模板】线段树1(区间修改,区间查询)
原题链接
题目详情
数据范围
- \(1≤N,M≤10^5\)
- 数列中任意元素的和在 \([-2^{63}, 2^{63})\)内
算法与思路
标记的含义:本区间已经被更新过了,但是子区间却没有被更新过,被更新的信息是什么。
线段树节点 node
首先此题涉及区间修改,我们需要加入懒标记记为 \(add\),表示当前节点区间已经加上了 \(add\),但是子树未被更新。
其次,我们还需要一个数 \(sum\) 来表示当前区间和。因此节点node即为:
| typedef long long LL;
struct node {
int l, r;
LL sum, add;
} tr[N << 2];
|
pushup操作
首先 \(pushup\) 操作是向上传递,我们并不需要把懒标记传给祖宗节点。所以只需要维护 \(sum\) 即可。
\[root.sum=left.sum+right.sum;\]
pushup
函数:
| void pushup(int u) {
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
|
pushdown操作
显然,父区间是包含子区间的,也就是对于父区间的标记和子区间是有联系的。在大多数情况下,父区间和子区间的标记是 相同的 。因此,我们可以由父区间的标记推算出子区间应当是什么标记。
在每次修改和查询操作时,我们需要将当前节点的懒标记清空并传递给子节点。
对于子节点,\(add\) 只需加上父节点的 \(add\),而 \(sum\) 是需要区间内每个数都加上 \(add\),因此转移方程为:
\[son.add+=root.add;\]
\[son.sum+=(son.r-son.l+1)*root.add;\]
pushdown
函数:
1
2
3
4
5
6
7
8
9
10
11
12 | void pushdown(int u) {
auto &root = tr[u];
auto &left = tr[u << 1];
auto &right = tr[u << 1 | 1];
if (root.add) {
left.add += root.add;
left.sum += 1ll * (left.r - left.l + 1) * root.add;
right.add += root.add;
right.sum += 1ll * (right.r - right.l + 1) * root.add;
root.add = 0; // 清空lazy tag
}
}
|
建树操作
建树时,叶子节点的 \(sum\) 即为当前数列中的数,懒标记 \(add\) 初始为 \(0\)。
注意:祖宗节点每次建立完左右子树后,需要 \(pushup\) 操作一次,修改当前节点的值。
build
函数:
| void build(int u, int l, int r) {
if (l == r) tr[u] = {l, r, w[r], 0};
else {
tr[u].l = l, tr[u].r = r;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
|
修改操作
此处的修改操作主要针对与懒标记,对于目标修改区间 \([l,r]\):
- 如果当前节点区间在 \([l,r]\) 内,修改当前节点;
- 如果当前节点区间范围超过 \([l,r]\),分别递归到左右子节点修改。
注意:每次做递归修改之前,需要将父节点的懒标记 \(pushdown\) 到左右子节点中!!!
modify
函数:
1
2
3
4
5
6
7
8
9
10
11
12 | void modify(int u, int l, int r, int v) {
if (tr[u].l >= l && tr[u].r <= r) {
tr[u].sum += 1ll * (tr[u].r - tr[u].l + 1) * v;
tr[u].add += v;
} else {
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, v);
if (r > mid) modify(u << 1 | 1, l, r, v);
pushup(u);
}
}
|
查询操作
对于每次查询,我们依旧需要把懒标记下传,防止最终答案是更新之前的值。
同样,对于目标查询区间 \([l,r]\):
- 当前节点区间在查询区间内,返回 \(root.sum\);
- 当前节点区间大于查询区间,返回 \(left.sum+right.sum\)。
注意:每次做递归修改之前,需要将父节点的懒标记 \(pushdown\) 到左右子节点中!!!
query
函数:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 | LL query(int u, int l, int r) {
// 1. 在目标区间内
// l--------------r
// Tl----Tr
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
LL sum = 0;
// 2. 在目标区间外
// l-----r
// Tl---m---Tr
if (l <= mid) sum = query(u << 1, l, r);
if (r > mid) sum += query(u << 1 | 1, l, r);
return sum;
}
|
代码(无注释)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84 | #include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 100010;
int n, m;
LL w[N];
struct node {
int l, r;
LL sum, add;
} tr[N << 2];
void pushup(int u) {
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u) {
auto &root = tr[u];
auto &left = tr[u << 1];
auto &right = tr[u << 1 | 1];
if (root.add) {
left.add += root.add;
left.sum += 1ll * (left.r - left.l + 1) * root.add;
right.add += root.add;
right.sum += 1ll * (right.r - right.l + 1) * root.add;
root.add = 0;
}
}
void build(int u, int l, int r) {
if (l == r) tr[u] = {l, r, w[r], 0};
else {
tr[u].l = l, tr[u].r = r;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int l, int r, LL v) {
if (tr[u].l >= l && tr[u].r <= r) {
tr[u].add += v;
tr[u].sum += 1ll * (tr[u].r - tr[u].l + 1) * v;
} else {
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, v);
if (r > mid) modify(u << 1 | 1, l, r, v);
pushup(u);
}
}
LL query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
LL sum = 0;
if (l <= mid) sum = query(u << 1, l, r);
if (r > mid) sum += query(u << 1 | 1, l, r);
return sum;
}
int main() {
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> w[i];
build(1, 1, n);
int op, l, r;
while (m--) {
cin >> op >> l >> r;
if (op == 1) {
LL x; cin >> x;
modify(1, l, r, x);
} else cout << query(1, l, r) << endl;
}
return 0;
}
|
运行结果