题目描述
Given two sequences $a$ and $b$ of equal length $n$, find the value of
$$
\sum_{x=1}^n \sum_{y=1}^n \sum_{z=1}^n \sum_{w=1}^n (a_x+a_y+a_z+a_w)^{(b_x \oplus b_y \oplus b_z \oplus b_w)}
$$
题意概述
给定两个长度为$n$的序列$a$和$b$,求$\sum_{x=1}^n \sum_{y=1}^n \sum_{z=1}^n \sum_{w=1}^n (a_x+a_y+a_z+a_w)^{(b_x \oplus b_y \oplus b_z \oplus b_w)}$。
数据范围:$1 \le n \le 10^5, \ 1 \le a_i \le 500, \ 1 \le b_i \le 500$。
算法分析
$a_x+a_y+a_z+a_w$的取值范围为$[4,2000]$,$b_x \oplus b_y \oplus b_z \oplus b_w$的取值范围为$[0,511]$,可以分别计算每种情况出现的次数再求和。
令$f_{1,i,j}$表示有多少个$x$满足$a_x=i \land b_x=j$。对于$k \ge 2$,令
$$
f_{k,i,j}=\sum_{i_1+i_2=i} \sum_{j_1 \oplus j_2=j} f_{k-1,i_1,j_1}f_{1,i_2,j_2}
$$
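$f_{4,i,j}$即要求的每种情况出现的次数。转移方程在第一维($i$)上是加法卷积(即普通多项式乘法),在第二维($j$)上是异或卷积,可以分别用 FFT(代码中为模 $998244353$ 意义下的 NTT)和 FWT 进行处理。由于这两种变换都把对应的卷积变为逐点乘积,只需对 $f_1$ 做一次二维正变换,逐点取四次方,再做逆变换即可得到 $f_4$;第一维长度取 $2048>2000$,保证循环卷积不会发生回绕。最终答案为 $\sum_{i,j} f_{4,i,j}\, i^{j}$,代码对 $998244353$ 取模输出。总时间复杂度为$O(a_{\max}b_{\max}\log(a_{\max}b_{\max}))$。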
代码
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
// Problem limits and modular-arithmetic constants shared by every routine below.
constexpr int N = 100005;        // capacity of the 1-indexed input arrays a[] and b[]
constexpr int M = 512;           // size of the XOR axis: b_i <= 500 < 512, so any XOR of them is < 512
constexpr int MOD = 998244353;   // prime modulus for the transforms and the final answer
constexpr int G = 3;             // base from which init() derives the roots of unity, power(G, (MOD-1)/len)
constexpr int INV2 = 499122177;  // modular inverse of 2, applied per level by the inverse fwt()
static_assert(2LL * INV2 % MOD == 1, "INV2 must be the inverse of 2 modulo MOD");
// Input sequences, read by main() into a[1..n] and b[1..n] (1-indexed).
int a[N];
int b[N];
// Bit-reversal permutation used by fft(); filled by init().
int rev[M << 2];
// Returns a^b modulo MOD using binary (square-and-multiply) exponentiation.
// Expects b >= 0. power(x, 0) == 1 for every x, including x == 0.
int power(int a, int b) {
    long long result = 1;
    long long base = a;
    while (b != 0) {
        if (b & 1) {
            result = result * base % MOD;
        }
        base = base * base % MOD;
        b >>= 1;
    }
    return static_cast<int>(result);
}
// Twiddle table for fft(): wn[k] = wn[1]^k for k < len/2, filled by init().
int wn[M << 2];
// Scratch buffer holding one column of f while it is transformed.
int A[M << 2];
// 2-D table indexed [a-sum][b-XOR]. main() first stores the counts of indices
// with (a_i, b_i) = (row, col), then transforms it in place into the counts of
// 4-tuples (taken modulo MOD).
int f[M << 2][M];
// Prepares the tables fft() relies on, for a transform of length
// len = the smallest power of two >= n (p = log2(len)):
//   * rev[i]: i with its low p bits reversed (rev[0] keeps its zero-initialized value);
//   * wn[k] = w^k for 2 <= k < len/2, with wn[0] = 1 and w = wn[1] = G^((MOD-1)/len).
// Assumes len divides MOD-1 and len <= M << 2 (the size of rev[] and wn[]).
void init(int n) {
    int len = 1;
    int p = 0;
    while (len < n) {
        len <<= 1;
        ++p;
    }
    for (int i = 1; i < len; ++i) {
        // Reverse the upper bits via rev[i >> 1], then place bit 0 at the top.
        const int topBit = (i & 1) << (p - 1);
        rev[i] = (rev[i >> 1] >> 1) | topBit;
    }
    wn[0] = 1;
    wn[1] = power(G, (MOD - 1) / len);
    const int half = len >> 1;
    for (int k = 2; k < half; ++k) {
        wn[k] = static_cast<int>(1ll * wn[k - 1] * wn[1] % MOD);
    }
}
// In-place iterative modular FFT (NTT) of a[0..len) modulo MOD.
// len must be the length init() was last called with, because the rev[] and
// wn[] tables are built for that length. With inv set, the inverse transform
// is produced: the forward transform, then reversal of a[1..len) and scaling by len^{-1}.
void fft(int *a, int len, bool inv = 0) {
    // Bit-reversal permutation so the butterflies can run bottom-up.
    for (int pos = 0; pos < len; ++pos) {
        const int partner = rev[pos];
        if (pos < partner) {
            std::swap(a[pos], a[partner]);
        }
    }
    for (int half = 1; half < len; half <<= 1) {
        // Stepping through wn[] by `stride` yields the (2*half)-th roots of unity.
        const int stride = len / (half << 1);
        for (int base = 0; base < len; base += half << 1) {
            int *lo = a + base;
            int *hi = lo + half;
            for (int k = 0; k < half; ++k) {
                const int u = lo[k];
                const int v = static_cast<int>(1ll * wn[stride * k] * hi[k] % MOD);
                lo[k] = (u + v) % MOD;
                hi[k] = (MOD + u - v) % MOD;
            }
        }
    }
    if (!inv) {
        return;
    }
    // Evaluating at w^{-k} equals evaluating at w^{len-k}, hence the reversal.
    std::reverse(a + 1, a + len);
    const int lenInverse = power(len, MOD - 2);
    for (int pos = 0; pos < len; ++pos) {
        a[pos] = static_cast<int>(1ll * a[pos] * lenInverse % MOD);
    }
}
// In-place XOR Walsh-Hadamard transform of a[0..len) modulo MOD; assumes len
// is a power of two. With inv set, both butterfly outputs are also multiplied
// by INV2 at every level. Over log2(len) levels this divides by len, which
// yields the inverse transform.
void fwt(int *a, int len, bool inv = 0) {
    for (int half = 1; half < len; half <<= 1) {
        for (int base = 0; base < len; base += half << 1) {
            for (int k = base; k < base + half; ++k) {
                const int u = a[k];
                const int v = a[k + half];
                int sum = (u + v) % MOD;
                int diff = (MOD + u - v) % MOD;
                if (inv) {
                    sum = static_cast<int>(1ll * sum * INV2 % MOD);
                    diff = static_cast<int>(1ll * diff * INV2 % MOD);
                }
                a[k] = sum;
                a[k + half] = diff;
            }
        }
    }
}
int main() { int n; scanf("%d", &n); for (int i = 1; i <= n; ++i) { scanf("%d", &a[i]); } for (int i = 1; i <= n; ++i) { scanf("%d", &b[i]); f[a[i]][b[i]] = f[a[i]][b[i]] + 1; } init(M << 2); for (int i = 0; i < M << 2; ++i) { for (int j = 0; j < M; ++j) { A[j] = f[i][j]; } fwt(A, M); for (int j = 0; j < M; ++j) { f[i][j] = A[j]; } } for (int i = 0; i < M; ++i) { for (int j = 0; j < M << 2; ++j) { A[j] = f[j][i]; } fft(A, M << 2); for (int j = 0; j < M << 2; ++j) { f[j][i] = A[j]; } } for (int i = 0; i < M << 2; ++i) { for (int j = 0; j < M; ++j) { f[i][j] = power(f[i][j], 4); } } for (int i = 0; i < M << 2; ++i) { for (int j = 0; j < M; ++j) { A[j] = f[i][j]; } fwt(A, M, 1); for (int j = 0; j < M; ++j) { f[i][j] = A[j]; } } for (int i = 0; i < M; ++i) { for (int j = 0; j < M << 2; ++j) { A[j] = f[j][i]; } fft(A, M << 2, 1); for (int j = 0; j < M << 2; ++j) { f[j][i] = A[j]; } } int ans = 0; for (int i = 0; i < M << 2; ++i) { for (int j = 0; j < M; ++j) { if (f[i][j]) { ans = (ans + 1ll * f[i][j] * power(i, j)) % MOD; } } } printf("%d\n", ans); return 0; }
|