题目描述
Bobo has a tree, whose vertices are conveniently labeled by $1, 2, \ldots, n$. At the very begining, the $i$-th vertex is assigned with weight $w_i$.
There are $q$ operations. Each operations are of the following $2$ types:
- change the weight of vertex $v$ into $x$ (denoted as “! $v$ $x$”),
- ask the total weight of vertices whose distance are no more than $d$ away from vertex $v$ (denoted as “? $v$ $d$”).
Note that the distance between vertex $u$ and $v$ is the number of edges on the shortest path between them.
题意概述
给定一棵有$n$个节点的树,第$i$个节点的权值为$w_i$。有两种操作:①将节点$v$的权值变成$x$;②询问与节点$v$的距离不超过$d$的所有节点的权值和。共有$q$次操作。
数据范围:$1 \le n, q \le 10^5, \; 0 \le w_i \le 10^4, \; 0 \le x \le 10^4, \; 0 \le d \le n$。
算法分析
对于我们枚举的某一个重心$x$,从它子树的重心向它连一条“虚”边,这样就形成了一棵由重心相连构成的“重心树”。由重心的性质可得,这棵树最多只有$O(\log n)$层。
如图,黑色表示原树上的边,红色表示重心树上的“虚”边。
对于每个节点,我们用树状数组存下它在“重心树”上的子树节点到它的距离为$p$(在原树上的距离)的权值和。当我们询问到节点$v$的距离不超过$d$的节点的权值和时,答案就等于$v$的树状数组中$d$的前缀和,再加上“重心树”上$v$的祖先节点$u$的树状数组中$d-dist_{u, v}$的前缀和。可以发现,$u$在“重心树”上包含$v$的子树中的节点被重复计算了。
如图,在计算到节点$3$的距离不超过$3$的节点的权值和时,计算了节点$3$子树中的节点$1, 3, 6, 8$;在计算到节点$3$的祖先节点$2$的距离不超过$3-2=1$的节点的权值和时,节点$1$又被计算了一次。
根据容斥原理,只需对每个节点再用一个树状数组记录下它在“重心树”上的子树节点到它在“重心树”上的父节点的距离为$p$(在原树上的距离)的权值和,每次计算时减去这个树状数组中$d-dist_{u, v}$的前缀和,就得到了正确答案。
修改节点$v$的权值时,只需更新“重心树”上$v$所有祖先节点的树状数组即可。
树上两点间的距离可以通过倍增求LCA得到,不过有更简便的方法。由于每个节点只会被$O(\log n)$个重心搜到,因此可以存下每个节点到其所有祖先重心节点的距离。
在“重心树”上操作的时间复杂度为$O(\log n)$,树状数组的时间复杂度为$O(\log n)$,总时间复杂度为$O(q\log^2n)$。
“重心树”上每一层节点的树状数组空间复杂度为$O(n)$,有$O(\log n)$层,每个节点上要存$O(\log n)$个距离,总空间复杂度为$O(n\log n)$。
代码
#include <cstdio> #include <cstring> #include <vector> using namespace std; struct edge { int v, nxt; } e[200001]; struct binary_indexed_tree { int n; vector<int> a; void init(int size) { n = size + 10, a.clear(), a.resize(n); } void add(int p, int t) { if (p) for (int i = p; i < n; i += i & -i) a[i] += t; } int sum(int p) { if (p <= 0) return 0; int ret = 0; for (int i = min(p, n - 1); i; i -= i & -i) ret += a[i]; return ret; } } tree[200001]; int n, q, nume, tot, top, root, h[100001], w[100001]; int size[100001], f[100001], up[100001], id[100001][2]; bool vis[100001]; vector<int> dist[100001]; void add_edge(int u, int v) { e[++nume].v = v, e[nume].nxt = h[u], h[u] = nume; e[++nume].v = u, e[nume].nxt = h[v], h[v] = nume; } void get_root(int t, int fa) { size[t] = 1, f[t] = 0; for (int i = h[t]; i; i = e[i].nxt) { if (!vis[e[i].v] && e[i].v != fa) { get_root(e[i].v, t); size[t] += size[e[i].v], f[t] = max(f[t], size[e[i].v]); } } f[t] = max(f[t], tot - size[t]); if (f[t] < f[root]) root = t; } void get_dist(int t, int fa, int depth, int flag) { if (flag) dist[t].push_back(depth); else { tree[id[root][0]].add(depth, w[t]); if (dist[t].size() > 1) tree[id[root][1]].add(dist[t][dist[t].size() - 2], w[t]); } for (int i = h[t]; i; i = e[i].nxt) { if (!vis[e[i].v] && e[i].v != fa) get_dist(e[i].v, t, depth + 1, flag); } } void solve(int t) { vis[t] = true; for (int i = h[t]; i; i = e[i].nxt) if (!vis[e[i].v]) get_dist(e[i].v, t, 1, 0); for (int i = h[t]; i; i = e[i].nxt) { if (!vis[e[i].v]) { root = 0, tot = size[e[i].v], get_root(e[i].v, t), up[root] = t; tree[id[root][0] = ++top].init(size[e[i].v]); tree[id[root][1] = ++top].init(size[e[i].v]); tree[id[root][1]].add(dist[root][dist[root].size() - 1], w[root]); get_dist(root, 0, 0, 1); solve(root); } } } void modify(int t, int d) { int p = t, top = dist[t].size() - 1; while (p) { if (top >= 0) tree[id[p][0]].add(dist[t][top], d); if (top > 0) tree[id[p][1]].add(dist[t][top - 1], d); p = up[p], --top; } } int ask(int t, int d) { int p = t, ret = 0, top = dist[t].size() - 1; while (p) { int len; if (top >= 0) len = d - dist[t][top]; if (top >= 0 && len >= 0) ret += tree[id[p][0]].sum(len) + w[p]; if (top > 0) len = d - dist[t][top - 1]; if (top > 0 && len >= 0 && up[p]) ret -= tree[id[p][1]].sum(len); p = up[p], --top; } return ret; } int main() { while (scanf("%d%d", &n, &q) != EOF) { nume = top = 0; memset(up, 0, sizeof(up)); memset(h, 0, sizeof(h)); memset(vis, 0, sizeof(vis)); for (int i = 1; i <= n; ++i) { scanf("%d", &w[i]); dist[i].clear(); } for (int i = 1; i < n; ++i) { int u, v; scanf("%d%d", &u, &v), add_edge(u, v); } root = 0, tot = f[0] = n, get_root(1, 0); tree[id[root][0] = ++top].init(size[1]); tree[id[root][1] = ++top].init(size[1]); get_dist(root, 0, 0, 1); solve(root); while (q--) { char oper = ' '; while (oper != '!' && oper != '?') scanf("%c", &oper); int v, d; scanf("%d%d", &v, &d); if (oper == '!') modify(v, d - w[v]), w[v] = d; else printf("%d\n", ask(v, d)); } } return 0; }