1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
// Dynamic centroid decomposition on a vertex-weighted tree, answering:
//   "! v x"  set the weight of vertex v to x
//   "? v d"  print the total weight of all vertices at tree distance <= d from v
// Input is read from stdin as repeated test cases until EOF:
//   n q, then n weights w[1..n], then n-1 edges "u v", then q operations.
//
// For every centroid c the program keeps two Fenwick trees over c's component:
//   tree[id[c][0]]: weight of each vertex x != c, indexed by dist(x, c);
//   tree[id[c][1]]: weight of each vertex x (c included), indexed by
//                   dist(x, up[c]), where up[c] is the parent centroid.
// dist[x][k] is the distance from x to its k-th centroid ancestor (k = 0 is
// the outermost centroid; the last entry is 0, the distance to x itself).
//
// NOTE(review): get_root / record_dist / fill_trees recurse as deep as the
// tree (up to n levels on a path graph); this relies on the judge providing
// a large enough stack.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

namespace {

const int MAXN = 100001;  // vertices are numbered 1..n

// Adjacency lists in array form; each undirected edge is stored twice.
struct edge {
    int v, nxt;
} e[2 * MAXN];

// Fenwick tree over positions 1..n-1. Positions <= 0 are ignored on update
// and never contribute to prefix sums.
struct binary_indexed_tree {
    int n;
    std::vector<int> a;

    // Resets the tree to all zeros, with room for positions up to cap.
    void init(int cap) {
        n = cap + 10;
        a.assign(n, 0);
    }
    // Adds t at position p; p <= 0 is a no-op (also avoids the endless
    // loop that `i += i & -i` would produce for negative p).
    void add(int p, int t) {
        if (p <= 0) return;
        for (int i = p; i < n; i += i & -i) a[i] += t;
    }
    // Sum of positions 1..p, with p clamped to the tree's capacity.
    int sum(int p) const {
        if (p <= 0) return 0;
        int ret = 0;
        for (int i = std::min(p, n - 1); i > 0; i -= i & -i) ret += a[i];
        return ret;
    }
} tree[2 * MAXN];  // two per centroid, handed out by ++top

int n, q, nume, tot, top, root;
int h[MAXN], w[MAXN];
// Deliberately not called `size`: next to std::size (C++17) an unqualified
// `size[...]` becomes ambiguous once namespace std is pulled in.
int sub_size[MAXN], f[MAXN], up[MAXN], id[MAXN][2];
bool vis[MAXN];
std::vector<int> dist[MAXN];

// Inserts the undirected edge u-v into both adjacency lists.
void add_edge(int u, int v) {
    e[++nume].v = v, e[nume].nxt = h[u], h[u] = nume;
    e[++nume].v = u, e[nume].nxt = h[v], h[v] = nume;
}

// Computes subtree sizes of the unvisited component containing t (rooted at
// t, coming from fa) and leaves in `root` the vertex whose largest remaining
// piece, judged against the component size `tot`, is smallest.
void get_root(int t, int fa) {
    sub_size[t] = 1;
    f[t] = 0;
    for (int i = h[t]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (vis[v] || v == fa) continue;
        get_root(v, t);
        sub_size[t] += sub_size[v];
        f[t] = std::max(f[t], sub_size[v]);
    }
    f[t] = std::max(f[t], tot - sub_size[t]);
    if (f[t] < f[root]) root = t;
}

// Appends each unvisited vertex's distance from the start vertex to its
// dist[] list (the start vertex is the new centroid, at depth 0).
void record_dist(int t, int fa, int depth) {
    dist[t].push_back(depth);
    for (int i = h[t]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (!vis[v] && v != fa) record_dist(v, t, depth + 1);
    }
}

// Adds w[t] (and recursively its unvisited subtree) to centroid c's trees:
// tree 0 at dist(t, c) = depth, tree 1 at dist(t, up[c]) which is the
// second-to-last entry of dist[t] when c has a parent centroid.
void fill_trees(int t, int fa, int depth, int c) {
    tree[id[c][0]].add(depth, w[t]);
    if (dist[t].size() > 1)
        tree[id[c][1]].add(dist[t][dist[t].size() - 2], w[t]);
    for (int i = h[t]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (!vis[v] && v != fa) fill_trees(v, t, depth + 1, c);
    }
}

// Allocates both Fenwick trees of centroid c for a component of cap vertices.
void alloc_trees(int c, int cap) {
    tree[id[c][0] = ++top].init(cap);
    tree[id[c][1] = ++top].init(cap);
}

// Decomposes the component whose centroid is t. On entry, every vertex of
// that component already has its distance to t as the last entry of dist[].
void solve(int t) {
    vis[t] = true;
    // Load t's trees with the current weights of the rest of its component.
    for (int i = h[t]; i; i = e[i].nxt)
        if (!vis[e[i].v]) fill_trees(e[i].v, t, 1, t);
    // Recurse into every piece left after removing t.
    for (int i = h[t]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (vis[v]) continue;
        root = 0;
        // sub_size[v] comes from the get_root pass that found t; it can exceed
        // the piece's real size when v lay toward that pass's start vertex.
        // tot is only used to pick the centroid, not for any sizing.
        tot = sub_size[v];
        get_root(v, t);
        int c = root;
        up[c] = t;
        alloc_trees(c, sub_size[v]);  // sub_size[v] is exact after get_root(v, t)
        tree[id[c][1]].add(dist[c].back(), w[c]);  // c itself, at dist(c, t)
        record_dist(c, 0, 0);
        solve(c);
    }
}

// Propagates a weight change of d at vertex t to every centroid ancestor.
void modify(int t, int d) {
    int level = static_cast<int>(dist[t].size()) - 1;
    for (int p = t; p != 0 && level >= 0; p = up[p], --level) {
        tree[id[p][0]].add(dist[t][level], d);  // depth 0 (p == t) is ignored
        if (level > 0) tree[id[p][1]].add(dist[t][level - 1], d);
    }
}

// Total weight of vertices within distance d of t.
int ask(int t, int d) {
    int ret = 0;
    int level = static_cast<int>(dist[t].size()) - 1;
    for (int p = t; p != 0 && level >= 0; p = up[p], --level) {
        int rest = d - dist[t][level];  // budget left at centroid p
        // tree 0 excludes p itself, so p's own weight is added explicitly.
        if (rest >= 0) ret += tree[id[p][0]].sum(rest) + w[p];
        if (level > 0 && up[p] != 0) {
            // Remove what the parent centroid will count again through p's
            // component (measured by distance to up[p]).
            int rest_up = d - dist[t][level - 1];
            if (rest_up >= 0) ret -= tree[id[p][1]].sum(rest_up);
        }
    }
    return ret;
}

// Returns the next '!' or '?' on stdin, skipping any other characters,
// or 0 if input ends first.
char read_operator() {
    char c;
    while (std::scanf(" %c", &c) == 1)
        if (c == '!' || c == '?') return c;
    return 0;
}

}  // namespace

int main() {
    while (std::scanf("%d%d", &n, &q) == 2) {
        nume = top = 0;
        std::memset(up, 0, sizeof(up));
        std::memset(h, 0, sizeof(h));
        std::memset(vis, 0, sizeof(vis));
        for (int i = 1; i <= n; ++i) {
            std::scanf("%d", &w[i]);
            dist[i].clear();
        }
        for (int i = 1; i < n; ++i) {
            int u, v;
            std::scanf("%d%d", &u, &v);
            add_edge(u, v);
        }
        // root = 0 with f[0] = n is the "worse than any vertex" sentinel;
        // get_root only ever writes f[] for real vertices (ids >= 1).
        root = 0;
        tot = f[0] = n;
        get_root(1, 0);
        int c = root;
        alloc_trees(c, sub_size[1]);
        record_dist(c, 0, 0);
        solve(c);
        while (q-- > 0) {
            char oper = read_operator();
            int v, d;
            if (oper == 0 || std::scanf("%d%d", &v, &d) != 2) break;
            if (oper == '!') {
                modify(v, d - w[v]);
                w[v] = d;
            } else {
                std::printf("%d\n", ask(v, d));
            }
        }
    }
    return 0;
}
|