Count on a Tree II

题目描述

You are given a tree with $N$ nodes. The tree nodes are numbered from $1$ to $N$. Each node has an integer weight.
We will ask you to perform the following operation:

  • $u$ $v$: ask for how many different integers that represent the weight of nodes there are on the path from $u$ to $v$.

题意概述

给定一棵$N$个节点的树,每个节点都有权值。有$M$次询问,每次询问给出$u, v$,求从$u$到$v$的路径上有多少种不同的权值。
数据范围:$1 \le N \le 40000, \; 1 \le M \le 10^5$。

算法分析

统计权值种类数,容易想到莫队,但这是在一棵树上,所以需要稍微做些转化。
求出树的括号序(DFS,访问和离开节点时均记录),那么每个询问都可以转化为对括号序上某个区间的询问。但是,如果区间内的某个节点出现了两次,那么这个节点不应该对答案产生贡献。因此,算法执行时可以用两个数组维护,一个表示节点的出现次数,另一个表示权值的出现次数。
具体来讲,令$s_i, e_i$分别表示访问、离开节点$i$的时间。假设$s_u \le s_v$。若$u$是$v$祖先,那么就相当于询问区间$[s_u, s_v]$;否则,相当于询问区间$[e_u, s_v]$,但此时$u, v$的LCA并没有被计算到答案中,因此需要把它加上。

代码

/*
 * Trap full -- please empty.
 */

#include <map>
#include <cstdio>
#include <cstring>
#include <algorithm>

template <typename T>
void read(T &n) {
  char c; for (; (c = getchar()) < '0' || c > '9'; ) ;
  for (n = c - '0'; (c = getchar()) >= '0' && c <= '9'; (n *= 10) += c - '0') ;
}

typedef int const ic;
typedef long long ll;

static ic N = 40005;
static ic M = 100005;
static ic K = 300;
std::map <int, int> num;
int nume, h[N], w[N], base[N], seq[N << 1];
int tim, st[N], ed[N], dep[N], up[N][16];
int mans, ml, mr, mvis[N], mrec[N], ans[M];
struct Edge {
  int v, nxt;
} e[N << 1];
struct Query {
  int l, r, lca, id;
  bool operator < (const Query &q) const {
    return l / K == q.l / K ? r < q.r : l / K < q.l / K;
  }
} q[M];

void add_edge(ic &u, ic &v) {
  e[++ nume] = (Edge) { v, h[u] }, h[u] = nume;
  e[++ nume] = (Edge) { u, h[v] }, h[v] = nume;
}

void dfs(ic &t, ic &fa) {
  seq[st[t] = ++ tim] = t, dep[t] = dep[fa] + 1, up[t][0] = fa;
  for (int i = h[t]; i; i = e[i].nxt)
    if (e[i].v != fa) dfs(e[i].v, t);
  seq[ed[t] = ++ tim] = t;
}

int get_lca(int u, int v) {
  if (dep[u] > dep[v]) std::swap(u, v);
  for (int i = 15; ~ i; -- i)
    if (dep[up[v][i]] >= dep[u]) v = up[v][i];
  if (u == v) return u;
  for (int i = 15; ~ i; -- i)
    if (up[u][i] != up[v][i]) u = up[u][i], v = up[v][i];
  return up[u][0];
}

void add(ic &u) {
  if (mvis[u]) {
    -- mrec[w[u]], mvis[u] = 0;
    if (! mrec[w[u]]) -- mans;
  } else {
    if (! mrec[w[u]]) ++ mans;
    ++ mrec[w[u]], mvis[u] = 1;
  }
}

void get(ic &l, ic &r) {
  for (; ml < l;) add(seq[ml ++]);
  for (; ml > l;) add(seq[-- ml]);
  for (; mr < r;) add(seq[++ mr]);
  for (; mr > r;) add(seq[mr --]);
}

int main() {
  int n, m; read(n), read(m);
  for (int i = 1; i <= n; ++ i) {
    read(w[i]);
    if (! num.count(w[i])) num[w[i]] = num.size();
    w[i] = num[w[i]];
  }
  for (int i = 1, u, v; i < n; ++ i) read(u), read(v), add_edge(u, v);
  dfs(1, 0);
  for (int i = 1; i < 16; ++ i)
    for (int j = 1; j <= n; ++ j) up[j][i] = up[up[j][i - 1]][i - 1];
  for (int i = 1, u, v; i <= m; ++ i) {
    read(u), read(v), q[i].id = i;
    if (st[u] > st[v]) std::swap(u, v);
    if (ed[u] > ed[v]) q[i].l = st[u] + 1, q[i].r = st[v], q[i].lca = u;
    else q[i].l = ed[u], q[i].r = st[v], q[i].lca = get_lca(u, v);
  }
  std::sort(q + 1, q + m + 1);
  mans = ml = mr = mvis[1] = mrec[w[1]] = 1;
  for (int i = 1; i <= m; ++ i) get(q[i].l, q[i].r), add(q[i].lca), ans[q[i].id] = mans, add(q[i].lca);
  for (int i = 1; i <= m; ++ i) printf("%d\n", ans[i]);
  return 0;
}

RegMs If

418 I'm a teapot

Leave a Reply

Your email address will not be published. Required fields are marked *