P3384 【模板】树链剖分
Solution
树剖+线段树模板题,注意取模
Code
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
using namespace std;
// 64-bit integer used for node weights, lazy tags and sums.
using LL = long long;
// Capacity of every global array below and of the segment tree's node arrays.
// NOTE(review): presumably four times the maximum node count so the segment
// tree fits — confirm against the problem limits.
constexpr int MAXN = 4e5 + 10;
// Adjacency lists of the undirected tree; main() stores each edge in both directions.
vector<int>tree[MAXN];
// Filled by dfs1(): parent, heavy child (-1 when there is none), depth and
// subtree size of each node. dep[] is also dfs1()'s "already visited" marker.
int fa[MAXN], hson[MAXN], dep[MAXN], siz[MAXN];
// Filled by dfs2(): head of the chain containing each node, node at each DFS
// position (rnk), and DFS position of each node (dfn).
int top[MAXN], rnk[MAXN], dfn[MAXN];
// Initial value of each node, read in main() for indices 1..n.
LL w[MAXN];
// n: number of nodes, p: modulus applied to all sums, cnt: running DFS
// position counter incremented by dfs2().
int n, p, cnt = 0;
// dfs1(): first decomposition pass. For node o and everything below it,
// records the parent (fa), the depth (dep), the subtree size (siz) and the
// heavy child (hson, -1 when o has no child).
// dep[] doubles as the visited marker, so the caller has to give the root a
// non-zero depth before the first call.
void dfs1(int o) {
    siz[o] = 1;
    hson[o] = -1;
    for (int child : tree[o]) {
        if (dep[child] != 0) continue;  // neighbour was already visited
        fa[child] = o;
        dep[child] = dep[o] + 1;
        dfs1(child);
        siz[o] += siz[child];
        // Keep the child with the strictly largest subtree seen so far.
        const bool noHeavyYet = (hson[o] == -1);
        if (noHeavyYet || siz[hson[o]] < siz[child]) hson[o] = child;
    }
}
// dfs2(): second decomposition pass. Records for node o the head of its chain
// (top), its 1-based position in the heavy-first DFS order (dfn) and the
// inverse mapping from position to node (rnk).
void dfs2(int o, int head) {
    const int order = ++cnt;
    top[o] = head;
    dfn[o] = order;
    rnk[order] = o;

    const int heavy = hson[o];
    if (heavy == -1) return;  // no child: the chain ends here

    // Visit the heavy child first so that the nodes of one chain receive
    // consecutive DFS positions.
    dfs2(heavy, head);

    // Every other child starts a new chain headed by itself.
    for (int next : tree[o]) {
        if (next == heavy || next == fa[o]) continue;
        dfs2(next, next);
    }
}
class segTree {
public:
void build(int o, int s, int t) {
laz[o] = 0;
if (s == t) { sum[o] = w[rnk[s]]; return; }//here
int mid = (s + t) / 2;
build(lson(o), s, mid);
build(rson(o), mid + 1, t);
push_up(o);
}
void update(int l, int r, int k) {
int fl = top[l], fr = top[r];
while (fl != fr) {
if (dep[fl] >= dep[fr])update(1, 1, n, dfn[fl], dfn[l], k), l = fa[fl];
else update(1, 1, n, dfn[fr], dfn[r], k), r = fa[fr];
fl = top[l], fr = top[r];
}
if (dfn[l] < dfn[r])update(1, 1, n, dfn[l], dfn[r], k);
else update(1, 1, n, dfn[r], dfn[l], k);
}
LL query_sum(int l, int r) {
LL ret = 0;
int fl = top[l], fr = top[r];
while (fl != fr) {
if (dep[fl] >= dep[fr])ret += querysum(1, 1, n, dfn[fl], dfn[l]) % p, l = fa[fl];
else ret += querysum(1, 1, n, dfn[fr], dfn[r]) % p, r = fa[fr];
fl = top[l], fr = top[r];
}
if (l != r) {
if (dfn[l] < dfn[r])ret += querysum(1, 1, n, dfn[l], dfn[r]) % p;
else ret += querysum(1, 1, n, dfn[r], dfn[l]) % p;
}
else ret += querysum(1, 1, n, dfn[l], dfn[r]) % p;
return ret;
}
LL laz[MAXN], sum[MAXN];
int lson(int x) { return x << 1; }
int rson(int x) { return x << 1 | 1; }
void push_up(int x) { sum[x] = (sum[lson(x)] + sum[rson(x)]) % p; }
void push_down(int x, int l, int r) {
int mid = (l + r) / 2;
if (laz[x] && l != r) {
laz[lson(x)] += laz[x];
laz[rson(x)] += laz[x];
sum[lson(x)] += laz[x] * (mid - l + 1);
sum[rson(x)] += laz[x] * (r - mid);
laz[x] = 0;
}
}
void update(int o, int s, int t, int l, int r, LL k) {
if (l > t || r < s)return;
else if (l <= s && t <= r) { sum[o] += (k * (t - s + 1)); sum[o] %= p; laz[o] += k; laz[o] %= p; return; }
push_down(o, s, t);
int mid = (s + t) / 2;
update(lson(o), s, mid, l, r, k);
update(rson(o), mid + 1, t, l, r, k);
push_up(o);
}
LL querysum(int o, int s, int t, int ql, int qr) {
if (ql > t || qr < s)return 0;
else if (ql <= s && t <= qr)return sum[o] % p;
push_down(o, s, t);
int mid = (s + t) / 2;
return (querysum(lson(o), s, mid, ql, qr) + querysum(rson(o), mid + 1, t, ql, qr)) % p;
}
};
int main(int argc, char **argv) {
int q, x, y, z, m, r;
cin >> n >> m >> r >> p;
for (int i = 1; i <= n; i++)cin >> w[i];
for (int i = 0; i < n - 1; i++) {
cin >> x >> y;
tree[x].push_back(y);
tree[y].push_back(x);
}
dep[r] = 1;
dfs1(r);
dfs2(r, r);
segTree ans;
ans.build(1, 1, n);
while (m--) {
cin >> q;
if (q == 1) {
cin >> x >> y >> z;
if (x > y)swap(x, y);
ans.update(x, y, z%p);
}
else if (q == 2) {
cin >> x >> y;
cout << ans.query_sum(x, y) % p << endl;
}
else if (q == 3) {
cin >> x >> z;
ans.update(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, z%p);
}
else {
cin >> x;
cout << ans.querysum(1, 1, n, dfn[x], dfn[x] + siz[x] - 1) % p << endl;
}
}
return 0;
}