题目描述
You are given a tree with $N$ nodes. The tree nodes are numbered from $1$ to $N$. Each node has an integer weight.
We will ask you to perform the following operation:
- $u$ $v$: ask for how many different integers that represent the weight of nodes there are on the path from $u$ to $v$.
题意概述
给定一棵$N$个节点的树,每个节点都有权值。有$M$次询问,每次询问给出$u, v$,求从$u$到$v$的路径上有多少种不同的权值。
数据范围:$1 \le N \le 40000, \; 1 \le M \le 10^5$。
算法分析
统计权值种类数,容易想到莫队,但这是在一棵树上,所以需要稍微做些转化。
求出树的括号序(DFS,访问和离开节点时均记录),那么每个询问都可以转化为对括号序上某个区间的询问。但是,如果区间内的某个节点出现了两次,那么这个节点不应该对答案产生贡献。因此,算法执行时可以用两个数组维护,一个表示节点的出现次数,另一个表示权值的出现次数。
具体来讲,令$s_i, e_i$分别表示访问、离开节点$i$的时间。假设$s_u \le s_v$。若$u$是$v$祖先,那么就相当于询问区间$[s_u, s_v]$;否则,相当于询问区间$[e_u, s_v]$,但此时$u, v$的LCA并没有被计算到答案中,因此需要把它加上。
代码
/* * Trap full -- please empty. */ #include <map> #include <cstdio> #include <cstring> #include <algorithm> template <typename T> void read(T &n) { char c; for (; (c = getchar()) < '0' || c > '9'; ) ; for (n = c - '0'; (c = getchar()) >= '0' && c <= '9'; (n *= 10) += c - '0') ; } typedef int const ic; typedef long long ll; static ic N = 40005; static ic M = 100005; static ic K = 300; std::map <int, int> num; int nume, h[N], w[N], base[N], seq[N << 1]; int tim, st[N], ed[N], dep[N], up[N][16]; int mans, ml, mr, mvis[N], mrec[N], ans[M]; struct Edge { int v, nxt; } e[N << 1]; struct Query { int l, r, lca, id; bool operator < (const Query &q) const { return l / K == q.l / K ? r < q.r : l / K < q.l / K; } } q[M]; void add_edge(ic &u, ic &v) { e[++ nume] = (Edge) { v, h[u] }, h[u] = nume; e[++ nume] = (Edge) { u, h[v] }, h[v] = nume; } void dfs(ic &t, ic &fa) { seq[st[t] = ++ tim] = t, dep[t] = dep[fa] + 1, up[t][0] = fa; for (int i = h[t]; i; i = e[i].nxt) if (e[i].v != fa) dfs(e[i].v, t); seq[ed[t] = ++ tim] = t; } int get_lca(int u, int v) { if (dep[u] > dep[v]) std::swap(u, v); for (int i = 15; ~ i; -- i) if (dep[up[v][i]] >= dep[u]) v = up[v][i]; if (u == v) return u; for (int i = 15; ~ i; -- i) if (up[u][i] != up[v][i]) u = up[u][i], v = up[v][i]; return up[u][0]; } void add(ic &u) { if (mvis[u]) { -- mrec[w[u]], mvis[u] = 0; if (! mrec[w[u]]) -- mans; } else { if (! mrec[w[u]]) ++ mans; ++ mrec[w[u]], mvis[u] = 1; } } void get(ic &l, ic &r) { for (; ml < l;) add(seq[ml ++]); for (; ml > l;) add(seq[-- ml]); for (; mr < r;) add(seq[++ mr]); for (; mr > r;) add(seq[mr --]); } int main() { int n, m; read(n), read(m); for (int i = 1; i <= n; ++ i) { read(w[i]); if (! num.count(w[i])) num[w[i]] = num.size(); w[i] = num[w[i]]; } for (int i = 1, u, v; i < n; ++ i) read(u), read(v), add_edge(u, v); dfs(1, 0); for (int i = 1; i < 16; ++ i) for (int j = 1; j <= n; ++ j) up[j][i] = up[up[j][i - 1]][i - 1]; for (int i = 1, u, v; i <= m; ++ i) { read(u), read(v), q[i].id = i; if (st[u] > st[v]) std::swap(u, v); if (ed[u] > ed[v]) q[i].l = st[u] + 1, q[i].r = st[v], q[i].lca = u; else q[i].l = ed[u], q[i].r = st[v], q[i].lca = get_lca(u, v); } std::sort(q + 1, q + m + 1); mans = ml = mr = mvis[1] = mrec[w[1]] = 1; for (int i = 1; i <= m; ++ i) get(q[i].l, q[i].r), add(q[i].lca), ans[q[i].id] = mans, add(q[i].lca); for (int i = 1; i <= m; ++ i) printf("%d\n", ans[i]); return 0; }