Digit Tree

题目描述

ZS the Coder has a large tree. It can be represented as an undirected connected graph of $n$ vertices numbered from $0$ to $n-1$ and $n-1$ edges between them. There is a single nonzero digit written on each edge.

One day, ZS the Coder was bored and decided to investigate some properties of the tree. He chose a positive integer $M$, which is coprime to $10$, i.e. $(M, 10)=1$.

ZS considers an ordered pair of distinct vertices $(u, v)$ interesting if, when he follows the shortest path from vertex $u$ to vertex $v$ and writes down all the digits he encounters on the path in order, he gets the decimal representation of an integer divisible by $M$.

Formally, ZS considers an ordered pair of distinct vertices $(u, v)$ interesting if the following holds:

  • let $a_1=u, a_2, \ldots, a_k=v$ be the sequence of vertices on the shortest path from $u$ to $v$ in the order of encountering them;
  • let $d_i (1 \le i \lt k)$ be the digit written on the edge between vertices $a_i$ and $a_{i+1}$;
  • the integer $\overline{d_1d_2 \ldots d_{k-1}}=\sum_{i=1}^{k-1} 10^{k-1-i}d_i$ is divisible by $M$.

Help ZS the Coder find the number of interesting pairs!

题意概述

给定一棵有$n$个节点的树和一个与$10$互质的数$M$,树上每条边的权值都是小于$10$的正整数。定义$dist_{u, v}$为依次写下从$u$到$v$路径上每条边的权值所得到的数字。求满足$M \mid dist_{u, v}$的有序点对$(u, v)$($u \ne v$)的个数。

数据范围:$2 \le n \le 10^5, \ 1 \le M \le 10^9$。

算法分析

设当前枚举到的节点为$x$。令$depth_u$表示$u$在$x$及它子树中的深度。对于在$x$第$(i+1)$棵子树中的节点$u$和在前$i$棵子树中的节点$v$,有:

$$
\begin{align}
M \mid dist_{u, v} \Leftrightarrow 10^{depth_v}dist_{u, x}+dist_{x, v} \equiv 0 \pmod M \\
M \mid dist_{v, u} \Leftrightarrow 10^{depth_u}dist_{v, x}+dist_{x, u} \equiv 0 \pmod M
\end{align}
$$

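对于$(1)$式,化简得$dist_{u, x} \equiv -10^{-depth_v}dist_{x, v} \pmod M$;对于$(2)$式,化简得$10^{-depth_u}dist_{x, u} \equiv -dist_{v, x} \pmod M$(由于$(M, 10)=1$,$10$在模$M$意义下存在逆元,故可以这样化简)。用两个 map 分别存下前$i$棵子树中$10^{-depth_v}dist_{x, v}$和$dist_{v, x}$的值,在处理第$(i+1)$棵子树时直接加上可行的方案数。此外,以$x$为一个端点的点对$(u, x)$和$(x, u)$需单独统计,即判断$dist_{u, x}$和$dist_{x, u}$是否被$M$整除。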

代码

#include <cstdio>
#include <map>
#include <algorithm>
using namespace std;
// One directed half of an undirected tree edge, kept in a forward-star list.
struct edge {
  int v;    // vertex this half-edge points to
  int w;    // digit written on the edge
  int nxt;  // index of the next half-edge leaving the same vertex (0 ends the list)
};

// Half-edge pool. Index 0 is never used for an edge, because list walks stop at 0.
edge e[200001];
// Global state shared by the centroid decomposition.
long long n;     // number of vertices (0 .. n-1); n itself is used as a sentinel id
long long m;     // the modulus M
long long ans;   // running count of interesting ordered pairs
long long nume;  // number of half-edges stored in e[] so far
long long tot;   // (possibly stale) size of the component being searched for a centroid
long long root;  // best centroid candidate found by get_root (starts at the sentinel n)
long long h[100001];  // h[u]: first half-edge leaving u (0 = none)
// NOTE: under C++17, `using namespace std` also makes std::size visible, so a
// bare `size[...]` is ambiguous; refer to this array as `::size`.
long long size[100001];   // subtree sizes from the most recent get_root DFS
long long f[100001];      // f[u]: largest piece left if u were removed
long long val1[100001];   // val1[u]: number read along u -> centre, mod m
long long val2[100001];   // val2[u]: number read along centre -> u, mod m
long long power[100001];  // power[i] = 10^i mod m
long long inv[100001];    // inv[i] = inverse of 10^i modulo m
bool vis[100001];         // vis[u]: u has already been used as a centre
std::map<long long, int> id1;  // count of dist(v -> centre) over recorded v
std::map<long long, int> id2;  // count of dist(centre -> v) * 10^{-depth_v} over recorded v
// Extended Euclid: on return, a * x + b * y == gcd(a, b).
// Iterative form. It runs the same quotient sequence as the textbook recursive
// version and produces the same coefficients (b == 0 gives x = 1, y = 0).
void extend_gcd(int a, int b, int &x, int &y) {
  int rem_prev = a, rem_cur = b;  // remainder sequence
  int ca_prev = 1, ca_cur = 0;    // coefficients of a
  int cb_prev = 0, cb_cur = 1;    // coefficients of b
  while (rem_cur != 0) {
    const int q = rem_prev / rem_cur;
    int next = rem_prev - q * rem_cur;
    rem_prev = rem_cur;
    rem_cur = next;
    next = ca_prev - q * ca_cur;
    ca_prev = ca_cur;
    ca_cur = next;
    next = cb_prev - q * cb_cur;
    cb_prev = cb_cur;
    cb_cur = next;
  }
  x = ca_prev;
  y = cb_prev;
}
// Store the undirected edge (u, v) carrying digit w as two half-edges, u->v
// and then v->u. Each one is prepended to its source vertex's list in h[].
void add_edge(int u, int v, int w) {
  const int ends[2][2] = {{u, v}, {v, u}};
  for (int k = 0; k < 2; ++k) {
    const int from = ends[k][0];
    const int to = ends[k][1];
    ++nume;
    e[nume].v = to;
    e[nume].w = w;
    e[nume].nxt = h[from];
    h[from] = nume;
  }
}
void get_root(int t, int fa) {
size[t] = 1, f[t] = 0;
for (int i = h[t]; i; i = e[i].nxt) {
if (!vis[e[i].v] && e[i].v != fa) {
get_root(e[i].v, t);
size[t] += size[e[i].v];
f[t] = max(f[t], size[e[i].v]);
}
}
f[t] = max(f[t], tot - size[t]);
if (f[t] < f[root]) root = t;
}
// Number of recorded values equal to key. Unlike operator[], this does not
// insert a zero-count node for every key that is only looked up.
static int count_of(const std::map<long long, int> &table, long long key) {
  const std::map<long long, int>::const_iterator it = table.find(key);
  return it == table.end() ? 0 : it->second;
}

// Walks one neighbour subtree of the current centre (the vertex solve() is
// processing); t is `depth` edges below the centre and was reached from fa.
// Here val1[t] is the number read along t -> centre (mod m) and val2[t] the
// number read along centre -> t (mod m).
//   flag != 0: count pairs between t and the centre, and between t and every
//              vertex of earlier subtrees recorded in id1/id2; also compute
//              val1/val2 for t's children.
//   flag == 0: record t in id1/id2, reusing val1/val2 from the flag != 0 pass,
//              so that later subtrees can pair with it.
void get_dist(int t, int fa, int flag, int depth) {
  if (!flag) {
    ++id1[val1[t]];                   // dist(v -> centre)
    ++id2[val2[t] * inv[depth] % m];  // dist(centre -> v) * 10^{-depth_v}
  } else {
    // (t, centre) and (centre, t): the path number itself must be 0 mod m.
    ans += !val1[t] + !val2[t];
    // (v, t): need dist(v -> centre) == -dist(centre -> t) * 10^{-depth}.
    ans += count_of(id1, (val2[t] ? m - val2[t] : 0) * inv[depth] % m);
    // (t, v): need dist(t -> centre) == -dist(centre -> v) * 10^{-depth_v}.
    ans += count_of(id2, val1[t] ? m - val1[t] : 0);
  }
  for (int i = h[t]; i; i = e[i].nxt) {
    const int to = e[i].v;
    if (vis[to] || to == fa) continue;
    if (flag) {
      // The digit on edge (t, to) becomes the most significant digit of
      // to -> centre and the least significant digit of centre -> to.
      val1[to] = (val1[t] + e[i].w * power[depth]) % m;
      val2[to] = (val2[t] * 10 + e[i].w) % m;
    }
    get_dist(to, t, flag, depth + 1);
  }
}
void solve(int t) {
vis[t] = true, id1.clear(), id2.clear();
for (int i = h[t]; i; i = e[i].nxt) {
if (!vis[e[i].v]) {
val1[e[i].v] = val2[e[i].v] = e[i].w % m;
get_dist(e[i].v, t, 1, 1);
get_dist(e[i].v, t, 0, 1);
}
}
for (int i = h[t]; i; i = e[i].nxt) {
if (!vis[e[i].v]) {
root = n, tot = size[e[i].v];
get_root(e[i].v, t);
solve(root);
}
}
}
int main()
{
  // Vertex count and modulus.
  scanf("%lld%lld", &n, &m);

  // For 0 <= i <= n: power[i] = 10^i mod m, and inv[i] = inverse of 10^i
  // modulo m, taken from the Bezout coefficient (it exists since gcd(10, m) = 1).
  for (int i = 0; i <= n; ++i) {
    power[i] = (i == 0) ? 1 : power[i - 1] * 10 % m;
    int coef, unused;
    extend_gcd(power[i], m, coef, unused);
    inv[i] = (coef % m + m) % m;
  }

  // The n - 1 edges, each given as "u v digit".
  for (long long k = 1; k < n; ++k) {
    int u, v, w;
    scanf("%d%d%d", &u, &v, &w);
    add_edge(u, v, w);
  }

  // Vertex id n is a sentinel centroid candidate. Its f value n is larger
  // than the f value of any real vertex.
  f[n] = n;
  tot = n;
  root = n;
  get_root(0, n);
  solve(root);

  printf("%lld\n", ans);
  return 0;
}

Digit Tree
https://regmsif.cf/2017/07/02/oi/digit-tree/
作者
RegMs If
发布于
2017年7月2日
许可协议