题目描述
Given two sequences $a$ and $b$ of equal length $n$, find the value of
$$\sum_{x=1}^n \sum_{y=1}^n \sum_{z=1}^n \sum_{w=1}^n (a_x+a_y+a_z+a_w)^{(b_x \oplus b_y \oplus b_z \oplus b_w)}$$
题意概述
给定两个长度为$n$的序列$a$和$b$,求$\sum_{x=1}^n \sum_{y=1}^n \sum_{z=1}^n \sum_{w=1}^n (a_x+a_y+a_z+a_w)^{(b_x \oplus b_y \oplus b_z \oplus b_w)}$(参考代码中答案对$998244353$取模)。
数据范围:$1 \le n \le 10^5, \; 1 \le a_i \le 500, \; 1 \le b_i \le 500$。
算法分析
$a_x+a_y+a_z+a_w$的取值范围为$[4,2000]$,$b_x \oplus b_y \oplus b_z \oplus b_w$的取值范围为$[0,511]$,可以分别计算每种情况出现的次数再求和。
令$f_{1,i,j}$表示有多少个$x$满足$a_x=i \land b_x=j$。对于$k \ge 2$,令
$$f_{k,i,j}=\sum_{i_1+i_2=i} \sum_{j_1 \oplus j_2=j} f_{k-1,i_1,j_1}f_{1,i_2,j_2}$$
$f_{4,i,j}$即要求的每种情况出现的次数。转移方程在第一维上是下标相加的卷积(即多项式乘法),在第二维上是异或卷积,可以分别用FFT(代码中为模$998244353$意义下的NTT)和FWT进行处理:对$f_1$做二维变换后逐点取四次方,再做逆变换即得$f_4$。总时间复杂度为$O(a_{\max}b_{\max}\log(a_{\max}b_{\max}))$。
代码
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>

// Computes, modulo MOD,
//   sum over x, y, z, w of (a_x + a_y + a_z + a_w) ^ (b_x xor b_y xor b_z xor b_w).
//
// f[s][t] starts as the number of indices with a = s and b = t.  Taking the
// 4-fold convolution of f with itself (additive in the first index, XOR in
// the second) yields the number of quadruples with sum s and xor t.  The
// convolution is done by transforming both axes (NTT on the sum axis, FWT on
// the xor axis), raising every entry to the 4th power, and transforming back.

// N: capacity of the input arrays.  M: size of the xor axis.
// MOD: prime modulus (MOD - 1 = 2^23 * 7 * 17, so power-of-two NTT lengths
// up to 2^23 are available).  G: generator used to build the roots of unity.
// INV2: inverse of 2 modulo MOD (2 * INV2 = MOD + 1).
int const N = 100005, M = 512, MOD = 998244353, G = 3, INV2 = 499122177;

// Size of the sum axis; sums of four values must stay below this so that the
// cyclic convolution does not wrap around.
int const LEN = M << 2;

int a[N], b[N], rev[LEN];
int wn[LEN], A[LEN], f[LEN][M];

// Returns a^b modulo MOD by binary exponentiation.
// Expects b >= 0; power(x, 0) is 1 for every x (including 0).
int power(int a, int b) {
  int ret = 1;
  for (; b; b >>= 1) {
    if (b & 1) {
      ret = 1ll * ret * a % MOD;
    }
    a = 1ll * a * a % MOD;
  }
  return ret;
}

// Prepares the tables used by fft() for transforms of length `len`, the
// smallest power of two that is >= n:
//   rev[i] - i with its lowest log2(len) bits reversed,
//   wn[i]  - i-th power of a primitive len-th root of unity, for i < len / 2.
// len must not exceed LEN (the size of rev and wn).
void init(int n) {
  int len = 1, p = 0;
  while (len < n) {
    len <<= 1;
    ++p;
  }
  for (int i = 1; i < len; ++i) {
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (p - 1));
  }
  wn[0] = 1;
  wn[1] = power(G, (MOD - 1) / len);
  for (int i = 2; i < (len >> 1); ++i) {
    wn[i] = 1ll * wn[i - 1] * wn[1] % MOD;
  }
}

// In-place number-theoretic transform of a[0..len) modulo MOD.
// `len` must be the power-of-two length that init() was last called for,
// because rev[] and wn[] are indexed relative to that length.
// With inv set, computes the inverse transform (including the 1/len factor).
void fft(int *a, int len, bool inv = 0) {
  // Bit-reversal permutation.
  for (int i = 0; i < len; ++i) {
    if (i < rev[i]) {
      std::swap(a[i], a[rev[i]]);
    }
  }
  // Iterative butterflies; `half` is half the size of the current sub-transform.
  for (int half = 1; half < len; half <<= 1) {
    int const stride = len / (half << 1);  // step through wn[] for this level
    for (int start = 0; start < len; start += half << 1) {
      for (int k = 0; k < half; ++k) {
        int const x = a[start + k];
        int const y = 1ll * wn[stride * k] * a[start + half + k] % MOD;
        a[start + k] = (x + y) % MOD;
        a[start + half + k] = (MOD + x - y) % MOD;
      }
    }
  }
  if (inv) {
    // Reversing a[1..len) turns the forward transform into the transform
    // with the inverse root; then divide by len.
    std::reverse(a + 1, a + len);
    int const len_inv = power(len, MOD - 2);
    for (int i = 0; i < len; ++i) {
      a[i] = 1ll * a[i] * len_inv % MOD;
    }
  }
}

// In-place Walsh-Hadamard (XOR) transform of a[0..len) modulo MOD; len must
// be a power of two.  With inv set, every butterfly output is halved, which
// yields the inverse transform.
void fwt(int *a, int len, bool inv = 0) {
  for (int half = 1; half < len; half <<= 1) {
    for (int start = 0; start < len; start += half << 1) {
      for (int k = 0; k < half; ++k) {
        int const x = a[start + k];
        int const y = a[start + half + k];
        a[start + k] = (x + y) % MOD;
        a[start + half + k] = (MOD + x - y) % MOD;
        if (inv) {
          a[start + k] = 1ll * a[start + k] * INV2 % MOD;
          a[start + half + k] = 1ll * a[start + half + k] * INV2 % MOD;
        }
      }
    }
  }
}

// Applies the (inverse) XOR transform along the second index of f.
// Each row f[i] is contiguous, so it is transformed in place.
static void transform_xor_axis(bool inv) {
  for (int i = 0; i < LEN; ++i) {
    fwt(f[i], M, inv);
  }
}

// Applies the (inverse) NTT along the first index of f.  Columns are strided
// in memory, so each one is copied through the scratch buffer A.
static void transform_sum_axis(bool inv) {
  for (int j = 0; j < M; ++j) {
    for (int i = 0; i < LEN; ++i) {
      A[i] = f[i][j];
    }
    fft(A, LEN, inv);
    for (int i = 0; i < LEN; ++i) {
      f[i][j] = A[i];
    }
  }
}

// Reads n, then n values of a, then n values of b from stdin and prints the
// answer modulo MOD.  Returns 1 (after a message on stderr) if the input is
// malformed or would index outside the tables.
int main() {
  int n;
  if (scanf("%d", &n) != 1 || n < 0 || n >= N) {
    fprintf(stderr, "invalid n\n");
    return 1;
  }
  for (int i = 1; i <= n; ++i) {
    // A sum of four values must fit on the sum axis without wrapping.
    if (scanf("%d", &a[i]) != 1 || a[i] < 0 || a[i] > (LEN - 1) / 4) {
      fprintf(stderr, "invalid a[%d]\n", i);
      return 1;
    }
  }
  for (int i = 1; i <= n; ++i) {
    if (scanf("%d", &b[i]) != 1 || b[i] < 0 || b[i] >= M) {
      fprintf(stderr, "invalid b[%d]\n", i);
      return 1;
    }
    ++f[a[i]][b[i]];
  }

  init(LEN);

  // Forward transforms on both axes.
  transform_xor_axis(false);
  transform_sum_axis(false);

  // Pointwise 4th power corresponds to the 4-fold convolution.
  for (int i = 0; i < LEN; ++i) {
    for (int j = 0; j < M; ++j) {
      f[i][j] = power(f[i][j], 4);
    }
  }

  // Inverse transforms; f[i][j] becomes the number of quadruples (mod MOD)
  // with sum i and xor j.
  transform_xor_axis(true);
  transform_sum_axis(true);

  int ans = 0;
  for (int i = 0; i < LEN; ++i) {
    for (int j = 0; j < M; ++j) {
      if (f[i][j]) {
        ans = (ans + 1ll * f[i][j] * power(i, j)) % MOD;
      }
    }
  }
  printf("%d\n", ans);
  return 0;
}