题目描述
Our child likes computer science very much; he especially likes binary trees.
Consider a sequence of $n$ distinct positive integers $c_1, c_2, \ldots, c_n$. The child calls a vertex-weighted rooted binary tree good if and only if, for every vertex $v$, the weight of $v$ is in the set $\{c_1, c_2, \ldots, c_n\}$. Also, our child defines the weight of a vertex-weighted tree as the sum of the weights of all its vertices.
Given an integer $m$, can you calculate, for every $s$ ($1 \le s \le m$), the number of good vertex-weighted rooted binary trees with weight $s$? Please check the samples to better understand which trees are considered different.
We only want to know the answer modulo $998244353$ ($7 \times 17 \times 2^{23}+1$, a prime number).
题意概述
有$n$种点,每种点有无限个,第$i$种点的权值为$c_i$。定义一棵二叉树的权值等于它所有点的权值之和。求对于所有$s \in [1, m]$,权值为$s$的二叉树有几棵。两棵二叉树不同当且仅当它们左子树或右子树不同,或者根节点权值不同。
数据范围:$1 \le n, m, c_i \le 10^5$。
算法分析
令$f(x)$表示权值为$x$的二叉树个数(约定空树的权值为$0$,故$f(0)=1$),$F(x)$为其生成函数($F(x)=\sum_{i \ge 0} f(i)x^i$)。
令$C(x)$为给定$c$的集合的生成函数($C(x)=\sum_{i=1}^n x^{c_i}$)。
根据DP转移方程,易知
$$
f(x)=\sum_{w \in \{c_1, c_2, \ldots, c_n\}} \sum_{i=0}^{x-w} f(i)f(x-w-i) \qquad (x \ge 1)
$$
将上式两边乘以对应的幂次后对所有$x \ge 1$求和,再补上空树对应的常数项$f(0)=1$,即得
$$
F(x)=C(x)F(x)^2+1
$$
解得
$$
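F(x)={1 \pm \sqrt{1-4C(x)} \over 2C(x)}={2 \over 1 \mp \sqrt{1-4C(x)}}
$$
(第二个等号由分子分母同乘$1 \mp \sqrt{1-4C(x)}$得到。)显然,若右式分母取减号(对应左式取加号),则当$x$趋近$0$时$C(x) \to 0$,分母$1-\sqrt{1-4C(x)} \to 0$,$F(x)$在$0$处没有定义;因此右式分母只能取加号,即$F(x)={2 \over 1+\sqrt{1-4C(x)}}$,此时常数项为$1$,与方程$F(x)=C(x)F(x)^2+1$在$x=0$处的取值一致。接着就是多项式开根和多项式求逆了。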
- 多项式求逆:
求$GF \equiv 1 \pmod {x^n}$。假设已知$G_0F \equiv 1 \pmod {x^{\lceil n/2 \rceil}}$
$G-G_0 \equiv 0 \pmod {x^{\lceil n/2 \rceil}}$
$G^2-2GG_0+G_0^2 \equiv 0 \pmod {x^n}$(两边平方:$(G-G_0)^2$的最低次项次数至少为$2\lceil n/2 \rceil \ge n$)
$G-2G_0+G_0^2F \equiv 0 \pmod {x^n}$(两边同乘$F$,并利用$GF \equiv 1 \pmod {x^n}$)
$G \equiv 2G_0-G_0^2F \pmod {x^n}$。递归边界为$n=1$,此时$G \equiv F_0^{-1}$($F_0$为$F$的常数项,要求其非零)。
- 多项式开根:
求$G^2 \equiv F \pmod {x^n}$。假设已知$G_0^2 \equiv F \pmod {x^{\lceil n/2 \rceil}}$
$(G_0^2-F)^2 \equiv 0 \pmod {x^n}$(因为$G_0^2-F \equiv 0 \pmod {x^{\lceil n/2 \rceil}}$)
$(G_0^2+F)^2 \equiv 4G_0^2F \pmod {x^n}$(两边同加$4G_0^2F$)
$\left({G_0^2+F \over 2G_0}\right)^2 \equiv F \pmod {x^n}$
$G \equiv {G_0+G_0^{-1}F \over 2} \pmod {x^n}$,其中$G_0^{-1}$需在模$x^n$意义下求出。递归边界为$n=1$时$G \equiv \sqrt{F_0}$;本题中$F_0=1$,取$G=1$。
代码
#include <algorithm>
#include <cstdio>
#include <cstring>

// Capacity of every global buffer. Buffers are indexed up to the NTT length
// chosen by init(), i.e. the smallest power of two >= 2 * (m + 1); for
// m <= 1e5 that length is 262144.
static const int N = 500000;
// 7 * 17 * 2^23 + 1 (prime); power-of-two NTT lengths up to 2^23 divide MOD - 1.
static const int MOD = 998244353;
// Base from which ntt() derives its roots of unity.
static const int G = 3;
// Modular inverse of 2 (2 * INV2 == MOD + 1).
static const int INV2 = 499122177;

// n, m, c[]: input. C[]: coefficients of 1 - 4C(x), later the final answer.
// rev/wn/tmp/tmp2/tmp3: scratch shared by init/ntt/get_inv/get_sqrt, so those
// routines are not reentrant with respect to each other's buffers.
int n, m, c[N], C[N], rev[N], wn[N], tmp[N], tmp2[N], tmp3[N];

// Returns a^b mod MOD by square-and-multiply. The exponent is first reduced
// modulo MOD - 1 (Fermat's little theorem).
int power(int a, int b) {
  long long base = a % MOD;
  int result = 1;
  for (b %= MOD - 1; b != 0; b >>= 1) {
    if (b & 1) result = static_cast<int>(result * base % MOD);
    base = base * base % MOD;
  }
  return result;
}

// Sets n to the smallest power of two that is >= twice its original value
// and fills rev[1..n) with the bit-reversal permutation for that length.
void init(int &n) {
  const int target = n * 2;
  int bits = 0;
  n = 1;
  while (n < target) {
    n *= 2;
    ++bits;
  }
  for (int i = 1; i < n; ++i)
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));
}

// In-place iterative NTT of a[0..n). n must be a power of two and rev[] must
// have been prepared by init() for this n. With inv == true the inverse
// transform (including the 1/n factor) is computed instead.
void ntt(int *a, int n, bool inv) {
  for (int i = 0; i < n; ++i)
    if (i < rev[i]) std::swap(a[i], a[rev[i]]);

  // wn[k] = w^k for a primitive n-th root of unity w; butterflies use k < n/2.
  wn[0] = 1;
  wn[1] = power(G, (MOD - 1) / n);
  for (int k = 2; k < n / 2; ++k)
    wn[k] = static_cast<int>(1LL * wn[k - 1] * wn[1] % MOD);

  for (int half = 1; half < n; half *= 2) {
    const int stride = n / (half * 2);  // maps w_{2*half}^k onto wn[]
    for (int start = 0; start < n; start += half * 2) {
      for (int k = 0; k < half; ++k) {
        const int u = a[start + k];
        const int v =
            static_cast<int>(1LL * wn[stride * k] * a[start + k + half] % MOD);
        a[start + k] = (u + v) % MOD;
        a[start + k + half] = (u - v + MOD) % MOD;
      }
    }
  }

  if (inv) {
    // Inverse NTT = forward NTT with outputs 1..n-1 reversed, scaled by 1/n.
    std::reverse(a + 1, a + n);
    const int scale = power(n, MOD - 2);
    for (int i = 0; i < n; ++i) a[i] = static_cast<int>(1LL * a[i] * scale % MOD);
  }
}

// Computes g = f^{-1} mod x^n by Newton iteration g = g0 * (2 - g0 * f).
// f[0] must be invertible mod MOD. Reads f[0..n); on return g[0..n) holds the
// inverse and g[n..L) is zero, where L is the NTT length init() picks for n.
// Clobbers the global scratch buffer tmp (and rev/wn).
void get_inv(int *f, int *g, int n) {
  if (n == 1) {
    g[0] = power(f[0], MOD - 2);
    return;
  }
  const int len = n;                // requested precision
  const int lower = (len + 1) / 2;  // precision of the recursive result
  get_inv(f, g, lower);

  init(n);  // from here on n is the NTT length
  std::fill(g + lower, g + n, 0);
  std::copy(f, f + len, tmp);
  std::fill(tmp + len, tmp + n, 0);

  ntt(g, n, false);
  ntt(tmp, n, false);
  for (int i = 0; i < n; ++i) {
    const long long correction = MOD + 2 - 1LL * g[i] * tmp[i] % MOD;
    g[i] = static_cast<int>(g[i] * correction % MOD);
  }
  ntt(g, n, true);
  std::fill(g + len, g + n, 0);
}

// Computes g = sqrt(f) mod x^n by Newton iteration g = (g0 + f / g0) / 2.
// The constant term of the root is fixed to 1, so f[0] is expected to be 1.
// Reads f[0..n); on return g[0..n) holds the root and g[n..L) is zero, where L
// is the NTT length init() picks for n. Clobbers tmp, tmp2, tmp3 (and rev/wn).
void get_sqrt(int *f, int *g, int n) {
  if (n == 1) {
    g[0] = 1;
    return;
  }
  const int len = n;
  const int lower = (len + 1) / 2;
  get_sqrt(f, g, lower);

  // tmp3 = g0^{-1} mod x^len, where g0 is g[0..lower).
  std::copy(g, g + lower, tmp2);
  std::fill(tmp2 + lower, tmp2 + len, 0);
  get_inv(tmp2, tmp3, len);

  init(n);  // from here on n is the NTT length
  std::fill(g + lower, g + n, 0);
  std::copy(f, f + len, tmp2);
  std::fill(tmp2 + len, tmp2 + n, 0);

  // tmp3 = f * g0^{-1}
  ntt(tmp2, n, false);
  ntt(tmp3, n, false);
  for (int i = 0; i < n; ++i)
    tmp3[i] = static_cast<int>(1LL * tmp3[i] * tmp2[i] % MOD);
  ntt(tmp3, n, true);

  for (int i = 0; i < len; ++i)
    g[i] = static_cast<int>(1LL * (g[i] + tmp3[i]) * INV2 % MOD);
  std::fill(g + len, g + n, 0);
}

int main() {
  scanf("%d%d", &n, &m);
  for (int i = 0; i < n; ++i) scanf("%d", &c[i]);

  // C[] <- coefficients of 1 - 4 * C(x) up to degree m, C(x) = sum x^{c_i}.
  for (int i = 0; i < n; ++i)
    if (c[i] <= m) ++C[c[i]];
  for (int i = 1; i <= m; ++i) C[i] = (MOD - 4 * C[i]) % MOD;
  C[0] = 1;

  // F(x) = 2 / (1 + sqrt(1 - 4C(x))); c[] is reused as scratch for the root.
  get_sqrt(C, c, m + 1);
  ++c[0];
  get_inv(c, C, m + 1);

  for (int i = 1; i <= m; ++i) printf("%d\n", 2 * C[i] % MOD);
  return 0;
}