# Count on a Tree II

## 题目描述

You are given a tree with $N$ nodes. The tree nodes are numbered from $1$ to $N$. Each node has an integer weight.
You must answer $M$ queries of the following form:

• $u$ $v$: report how many distinct weights appear among the nodes on the path from $u$ to $v$ (both endpoints included).

## 代码

/*
* Trap full -- please empty.
*/

#include <map>
#include <cstdio>
#include <cstring>
#include <algorithm>

/*
 * Reads the next integer from stdin into n.
 * Leading non-digit characters are skipped; a '-' directly before the first
 * digit makes the value negative. The character that terminates the number
 * is consumed. If stdin runs out before any digit is found, n is left
 * unchanged (the previous version looped forever at EOF because a `char`
 * cannot represent EOF).
 */
template <typename T>
void read(T &n) {
    int c = getchar();
    bool negative = false;
    while (c != EOF && (c < '0' || c > '9')) {
        negative = (c == '-');  // only a '-' immediately before the digits counts
        c = getchar();
    }
    if (c == EOF) return;
    n = 0;
    for (; c >= '0' && c <= '9'; c = getchar()) n = n * 10 + (c - '0');
    if (negative) n = -n;
}

/* Shorthand types: `ic` is the read-only int used for parameters and limits. */
typedef const int ic;
typedef long long ll;

/* Capacities: N bounds node indices, M bounds query indices, and K is the
 * block length used to bucket query left endpoints for Mo's ordering. */
static const int N = 40005;
static const int M = 100005;
static const int K = 300;

/* Weight compression table: original weight -> dense id. */
std::map <int, int> num;

/* Adjacency lists and per-node data. */
int nume;           // number of arcs stored in e
int h[N];           // first arc index of each node's list (0 = empty)
int w[N];           // node weights
int base[N];        // NOTE(review): not referenced in the visible code
int seq[N << 1];    // Euler sequence: each node is written at st[] and ed[]

/* Euler-tour timestamps, depths and the binary-lifting table. */
int tim;            // last position written into seq
int st[N];
int ed[N];
int dep[N];
int up[N][16];      // up[x][i]: ancestor of x 2^i levels up (0 above the root)

/* Mo's algorithm state. */
int mans;           // number of distinct weights in the current window
int ml;             // window [ml, mr] over seq
int mr;
int mvis[N];        // whether a node is currently counted
int mrec[N];        // counted nodes per compressed weight
int ans[M];         // answers indexed by original query order

/* Adjacency-list arc: target node and index of the next arc from the same
 * source (0 terminates the list). */
struct Edge {
    int v;
    int nxt;
};
Edge e[N << 1];

/* An offline query as a range [l, r] over seq plus the LCA node, which is
 * toggled separately when answering. */
struct Query {
    int l;
    int r;
    int lca;
    int id;
    // Mo's ordering: by block of l (size K), then by r ascending.
    bool operator < (const Query &other) const {
        const int mine = l / K;
        const int theirs = other.l / K;
        if (mine != theirs) return mine < theirs;
        return r < other.r;
    }
};
Query q[M];

void add_edge(ic &u, ic &v) {
e[++ nume] = (Edge) { v, h[u] }, h[u] = nume;
e[++ nume] = (Edge) { u, h[v] }, h[v] = nume;
}

/*
 * Depth-first traversal of the subtree of t, whose parent is fa.
 * Writes t into seq at position st[t] on entry and again at ed[t] on exit,
 * and records dep[t] and the direct parent up[t][0].
 */
void dfs(ic &t, ic &fa) {
    ++ tim;
    st[t] = tim;
    seq[tim] = t;
    dep[t] = dep[fa] + 1;
    up[t][0] = fa;
    for (int arc = h[t]; arc != 0; arc = e[arc].nxt) {
        const int child = e[arc].v;
        if (child == fa) continue;  // do not walk back to the parent
        dfs(child, t);
    }
    ++ tim;
    ed[t] = tim;
    seq[tim] = t;
}

int get_lca(int u, int v) {
if (dep[u] > dep[v]) std::swap(u, v);
for (int i = 15; ~ i; -- i)
if (dep[up[v][i]] >= dep[u]) v = up[v][i];
if (u == v) return u;
for (int i = 15; ~ i; -- i)
if (up[u][i] != up[v][i]) u = up[u][i], v = up[v][i];
return up[u][0];
}

if (mvis[u]) {
-- mrec[w[u]], mvis[u] = 0;
if (! mrec[w[u]]) -- mans;
} else {
if (! mrec[w[u]]) ++ mans;
++ mrec[w[u]], mvis[u] = 1;
}
}

/*
 * Moves the Mo window [ml, mr] over seq to [l, r].
 * Each seq position that enters or leaves the window has its node passed to
 * add(), which toggles it, so entering and leaving use the same call.
 */
void get(ic &l, ic &r) {
    while (ml < l) {
        add(seq[ml]);
        ++ ml;
    }
    while (ml > l) {
        -- ml;
        add(seq[ml]);
    }
    while (mr < r) {
        ++ mr;
        add(seq[mr]);
    }
    while (mr > r) {
        add(seq[mr]);
        -- mr;
    }
}

int main() {
for (int i = 1; i <= n; ++ i) {
if (! num.count(w[i])) num[w[i]] = num.size();
w[i] = num[w[i]];
}
dfs(1, 0);
for (int i = 1; i < 16; ++ i)
for (int j = 1; j <= n; ++ j) up[j][i] = up[up[j][i - 1]][i - 1];
for (int i = 1, u, v; i <= m; ++ i) {
if (st[u] > st[v]) std::swap(u, v);
if (ed[u] > ed[v]) q[i].l = st[u] + 1, q[i].r = st[v], q[i].lca = u;
else q[i].l = ed[u], q[i].r = st[v], q[i].lca = get_lca(u, v);
}
std::sort(q + 1, q + m + 1);
mans = ml = mr = mvis[1] = mrec[w[1]] = 1;
for (int i = 1; i <= m; ++ i) get(q[i].l, q[i].r), add(q[i].lca), ans[q[i].id] = mans, add(q[i].lca);
for (int i = 1; i <= m; ++ i) printf("%d\n", ans[i]);
return 0;
}


418 I'm a teapot