Skip to the content.

P3384 【模板】树链剖分

Link

Solution

树链剖分 + 线段树模板题。注意线段树中的区间和与懒标记都要对 p 取模,避免中间结果溢出。

Code

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
using namespace std;
// 64-bit integer type used for node weights, segment sums and lazy tags.
using LL = long long;
// Capacity of every global array below and of the segment tree's node arrays.
constexpr int MAXN = 4e5 + 10;

// Adjacency lists; main() stores each input edge x-y in both tree[x] and tree[y].
vector<int>tree[MAXN];
// Filled by dfs1(): parent, heavy child (-1 when a node has no child),
// depth (main() presets the root's depth to 1) and subtree size.
int fa[MAXN], hson[MAXN], dep[MAXN], siz[MAXN];
// Filled by dfs2(): head of the node's chain, node at a given DFS position,
// and DFS position of a given node (positions start at 1).
int top[MAXN], rnk[MAXN], dfn[MAXN];
// Initial node weights, read by main() into w[1..n].
LL w[MAXN];
// n: number of nodes; p: modulus for all reported sums;
// cnt: DFS position counter incremented by dfs2().
int n, p, cnt = 0;

// First DFS pass. For every node reachable from `o` it records the parent
// (fa), depth (dep), subtree size (siz) and heavy child (hson, -1 if none).
// A nonzero dep[] entry doubles as the "visited" flag, so the caller must set
// the root's depth to a nonzero value before the first call.
void dfs1(int o) {
    siz[o] = 1;
    hson[o] = -1;
    for (const int child : tree[o]) {
        // Already has a depth: this is the parent (or another visited node).
        if (dep[child] != 0) continue;

        fa[child] = o;
        dep[child] = dep[o] + 1;
        dfs1(child);

        siz[o] += siz[child];
        // Keep the child with the strictly largest subtree as the heavy child.
        const int heavy = hson[o];
        if (heavy == -1 || siz[heavy] < siz[child]) {
            hson[o] = child;
        }
    }
}
// Second DFS pass. Records the chain head (top), the DFS position of each
// node (dfn) and the node at each DFS position (rnk). `head` is the head of
// the chain that `o` belongs to.
void dfs2(int o, int head) {
    const int order = ++cnt;
    top[o] = head;
    dfn[o] = order;
    rnk[order] = o;

    const int heavy = hson[o];
    if (heavy != -1) {
        // Visit the heavy child first so that the nodes of one heavy chain
        // occupy consecutive DFS positions.
        dfs2(heavy, head);
        // Every other child starts a new chain headed by itself.
        for (const int next : tree[o]) {
            if (next == heavy || next == fa[o]) continue;
            dfs2(next, next);
        }
    }
}

class segTree {
public:
    void build(int o, int s, int t) {
        laz[o] = 0;
        if (s == t) { sum[o] = w[rnk[s]]; return; }//here
        int mid = (s + t) / 2;
        build(lson(o), s, mid);
        build(rson(o), mid + 1, t);
        push_up(o);
    }

    void update(int l, int r, int k) {
        int fl = top[l], fr = top[r];
        while (fl != fr) {
            if (dep[fl] >= dep[fr])update(1, 1, n, dfn[fl], dfn[l], k), l = fa[fl];
            else update(1, 1, n, dfn[fr], dfn[r], k), r = fa[fr];
            fl = top[l], fr = top[r];
        }
        if (dfn[l] < dfn[r])update(1, 1, n, dfn[l], dfn[r], k);
        else update(1, 1, n, dfn[r], dfn[l], k);
    }

    LL query_sum(int l, int r) {
        LL ret = 0;
        int fl = top[l], fr = top[r];
        while (fl != fr) {
            if (dep[fl] >= dep[fr])ret += querysum(1, 1, n, dfn[fl], dfn[l]) % p, l = fa[fl];
            else ret += querysum(1, 1, n, dfn[fr], dfn[r]) % p, r = fa[fr];
            fl = top[l], fr = top[r];
        }
        if (l != r) {
            if (dfn[l] < dfn[r])ret += querysum(1, 1, n, dfn[l], dfn[r]) % p;
            else ret += querysum(1, 1, n, dfn[r], dfn[l]) % p;
        }
        else ret += querysum(1, 1, n, dfn[l], dfn[r]) % p;
        return ret;
    }

    LL laz[MAXN], sum[MAXN];
    int lson(int x) { return x << 1; }
    int rson(int x) { return x << 1 | 1; }
    void push_up(int x) { sum[x] = (sum[lson(x)] + sum[rson(x)]) % p; }
    void push_down(int x, int l, int r) {
        int mid = (l + r) / 2;
        if (laz[x] && l != r) {
            laz[lson(x)] += laz[x];
            laz[rson(x)] += laz[x];
            sum[lson(x)] += laz[x] * (mid - l + 1);
            sum[rson(x)] += laz[x] * (r - mid);
            laz[x] = 0;
        }
    }
    void update(int o, int s, int t, int l, int r, LL k) {
        if (l > t || r < s)return;
        else if (l <= s && t <= r) { sum[o] += (k * (t - s + 1)); sum[o] %= p; laz[o] += k; laz[o] %= p; return; }
        push_down(o, s, t);
        int mid = (s + t) / 2;
        update(lson(o), s, mid, l, r, k);
        update(rson(o), mid + 1, t, l, r, k);
        push_up(o);
    }
    LL querysum(int o, int s, int t, int ql, int qr) {
        if (ql > t || qr < s)return 0;
        else if (ql <= s && t <= qr)return sum[o] % p;
        push_down(o, s, t);
        int mid = (s + t) / 2;
        return (querysum(lson(o), s, mid, ql, qr) + querysum(rson(o), mid + 1, t, ql, qr)) % p;
    }
};

int main(int argc, char **argv) {
    int q, x, y, z, m, r;
    cin >> n >> m >> r >> p;
    for (int i = 1; i <= n; i++)cin >> w[i];
    for (int i = 0; i < n - 1; i++) {
        cin >> x >> y;
        tree[x].push_back(y);
        tree[y].push_back(x);
    }
    dep[r] = 1;
    dfs1(r);
    dfs2(r, r);
    segTree ans;
    ans.build(1, 1, n);
    while (m--) {
        cin >> q;
        if (q == 1) {
            cin >> x >> y >> z;
            if (x > y)swap(x, y);
            ans.update(x, y, z%p);
        }
        else if (q == 2) {
            cin >> x >> y;
            cout << ans.query_sum(x, y) % p << endl;
        }
        else if (q == 3) {
            cin >> x >> z;
            ans.update(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, z%p);
        }
        else {
            cin >> x;
            cout << ans.querysum(1, 1, n, dfn[x], dfn[x] + siz[x] - 1) % p << endl;
        }
    }
    return 0;
}