1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
#include <cstdio>
#include <map>
#include <algorithm>
#include <cstring>

// Path-product queries on a tree, solved with centroid decomposition.
//
// Input (repeated until EOF): n and k, then n node weights, then n-1
// undirected edges. For each test case, print the lexicographically smallest
// pair "a b" (a < b) such that the product of the weights of every node on
// the path a..b, taken mod MOD, equals k; print "No solution" otherwise.
//
// NOTE(review): assumes 1 <= n < MAXN and 0 <= k < MOD, as the array sizes
// and the comparison against reduced products imply -- confirm with the
// problem statement.

const int MOD = 1000003;        // prime, so every nonzero residue is invertible
const int MAXN = 100001;        // node ids are 1..MAXN-1
const int MAXE = 2 * MAXN;      // each undirected edge is stored twice
const int NO_NODE = 1000000000; // sentinel larger than any node id

struct Edge {
    int to, next;
};

Edge edges[MAXE];
int head[MAXN];            // head[u]: first edge index out of u (0 = none)
int edge_count;
long long weight[MAXN];    // node weights, reduced into [0, MOD)
long long path_prod[MAXN]; // product of weights from a centroid's neighbour down to a node
long long inv[MOD];        // inv[i] = inverse of i mod MOD for i >= 1 (inv[0] stays 0)
bool removed[MAXN];        // true once a node has been used as a centroid
int sub_size[MAXN];        // subtree sizes inside the component being scanned
int max_part[MAXN];        // largest child-subtree size inside that component
int order_buf[MAXN];       // BFS order of the component being scanned
int parent_of[MAXN];       // BFS parent of each node inside that component
long long k;               // target product for the current test case
int best_a, best_b;        // best pair found so far, best_a <= best_b

typedef std::map<long long, int> ProductIndex;
ProductIndex first_node;   // path product -> smallest node id seen with that product

// Adds the undirected edge a-b to the adjacency lists.
static void add_edge(int a, int b) {
    edges[++edge_count].to = b;
    edges[edge_count].next = head[a];
    head[a] = edge_count;
    edges[++edge_count].to = a;
    edges[edge_count].next = head[b];
    head[b] = edge_count;
}

// Stores, in BFS order in order_buf, every node reachable from `start`
// without stepping onto `from` or a removed centroid; records each node's
// BFS parent in parent_of. Returns the number of nodes found.
// Iterative on purpose: recursion depth would equal the tree depth, which
// overflows the stack on long paths.
static int collect_component(int start, int from) {
    int count = 0;
    order_buf[count++] = start;
    parent_of[start] = from;
    for (int i = 0; i < count; ++i) {
        const int u = order_buf[i];
        for (int j = head[u]; j; j = edges[j].next) {
            const int w = edges[j].to;
            if (w != parent_of[u] && !removed[w]) {
                parent_of[w] = u;
                order_buf[count++] = w;
            }
        }
    }
    return count;
}

// Returns a centroid of the component containing `start`: the node whose
// largest remaining piece after its removal is smallest. Subtree sizes are
// recomputed for this component, so the choice never uses stale sizes.
static int find_centroid(int start, int from) {
    const int count = collect_component(start, from);
    for (int i = 0; i < count; ++i) {
        sub_size[order_buf[i]] = 1;
        max_part[order_buf[i]] = 0;
    }
    // Reverse BFS order visits every child before its parent.
    for (int i = count - 1; i > 0; --i) {
        const int u = order_buf[i];
        const int p = parent_of[u];
        sub_size[p] += sub_size[u];
        max_part[p] = std::max(max_part[p], sub_size[u]);
    }
    int best = start;
    int best_part = count; // any real piece is at most count - 1
    for (int i = 0; i < count; ++i) {
        const int u = order_buf[i];
        const int part = std::max(max_part[u], count - sub_size[u]);
        if (part < best_part) {
            best_part = part;
            best = u;
        }
    }
    return best;
}

// Records the pair {a, b} if it is lexicographically smaller (after
// ordering so the first element is the smaller id) than the best so far.
static void update_best(int a, int b) {
    if (a > b) std::swap(a, b);
    if (a < best_a || (a == best_a && b < best_b)) {
        best_a = a;
        best_b = b;
    }
}

// Finds every qualifying pair whose path passes through `centroid`, then
// recurses into the components left after removing it.
static void solve(int centroid) {
    removed[centroid] = true;
    first_node.clear();
    const long long centroid_weight = weight[centroid];

    for (int j = head[centroid]; j; j = edges[j].next) {
        const int child = edges[j].to;
        if (removed[child]) continue;
        const int count = collect_component(child, centroid);

        // Products along child..u (centroid excluded); BFS order guarantees
        // the parent's product is ready before the node's.
        path_prod[child] = weight[child];
        for (int i = 1; i < count; ++i) {
            const int u = order_buf[i];
            path_prod[u] = path_prod[parent_of[u]] * weight[u] % MOD;
        }

        // Match this branch against the centroid itself and against the
        // branches already inserted (never against the same branch, whose
        // paths would not pass through the centroid).
        for (int i = 0; i < count; ++i) {
            const int u = order_buf[i];
            const long long through = path_prod[u] * centroid_weight % MOD;
            if (through == k) update_best(u, centroid);
            if (through != 0) { // 0 has no inverse mod MOD
                const long long need = k * inv[through] % MOD;
                ProductIndex::const_iterator it = first_node.find(need);
                if (it != first_node.end()) update_best(it->second, u);
            }
        }

        // Make this branch visible to later branches, keeping the smallest
        // node id per product (the smallest id always yields the
        // lexicographically best pair for a given partner).
        for (int i = 0; i < count; ++i) {
            const int u = order_buf[i];
            std::pair<ProductIndex::iterator, bool> ins =
                first_node.insert(std::make_pair(path_prod[u], u));
            if (!ins.second && u < ins.first->second) ins.first->second = u;
        }
    }

    for (int j = head[centroid]; j; j = edges[j].next) {
        const int child = edges[j].to;
        if (!removed[child]) solve(find_centroid(child, centroid));
    }
}

int main() {
    // Linear-time inverse table; valid because MOD is prime.
    inv[1] = 1;
    for (int i = 2; i < MOD; ++i) {
        inv[i] = inv[MOD % i] * (MOD - MOD / i) % MOD;
    }

    int n;
    // `== 2` rather than `!= EOF`: a malformed header would otherwise make
    // scanf return 0 forever and spin this loop.
    while (std::scanf("%d%lld", &n, &k) == 2) {
        best_a = best_b = NO_NODE;
        edge_count = 0;
        std::memset(removed, 0, sizeof(removed));
        std::memset(head, 0, sizeof(head));
        for (int i = 1; i <= n; ++i) {
            long long w = 0;
            std::scanf("%lld", &w);
            // Reduce so the value is always a valid index into inv[].
            weight[i] = ((w % MOD) + MOD) % MOD;
        }
        for (int i = 1; i < n; ++i) {
            int a = 0, b = 0;
            std::scanf("%d%d", &a, &b);
            add_edge(a, b);
        }
        solve(find_centroid(1, 0));
        if (best_b <= n) std::printf("%d %d\n", best_a, best_b);
        else std::printf("No solution\n");
    }
    return 0;
}
|