# Four Loop

## 题目描述

Given two sequences $a$ and $b$ of equal length $n$, find the value of

$$\sum_{x=1}^n \sum_{y=1}^n \sum_{z=1}^n \sum_{w=1}^n (a_x+a_y+a_z+a_w)^{(b_x \oplus b_y \oplus b_z \oplus b_w)}$$

## 算法分析

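$a_x+a_y+a_z+a_w$的取值范围为$[4,2000]$，$b_x \oplus b_y \oplus b_z \oplus b_w$的取值范围为$[0,511]$，可以分别计算每种情况出现的次数再求和。设$f_{k,i,j}$表示有序地选取$k$个下标（可重复），使得对应的$a$之和为$i$、$b$的异或和为$j$的方案数，其中$f_{1,i,j}$即满足$a_t=i$且$b_t=j$的下标$t$的个数，则有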

$$f_{k,i,j}=\sum_{i_1+i_2=i} \sum_{j_1 \oplus j_2=j} f_{k-1,i_1,j_1}f_{1,i_2,j_2}$$

$f_{4,i,j}$即要求的每种情况出现的次数。转移方程在第一维上是加法卷积（即多项式乘法），在第二维上是异或卷积，可以分别用FFT（代码中使用模$998244353$的NTT）和FWT进行处理：对$f_1$先做两种正变换，逐点取$4$次幂，再做两种逆变换即可。总时间复杂度为$O(a_{\max}b_{\max}\log(a_{\max}b_{\max}))$。

## 代码

#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>

// Capacity of the input arrays a and b.
const int N = 100005;
// Size of the XOR dimension; the additive dimension uses M << 2 slots.
const int M = 512;
// Modulus applied to every arithmetic result.
const int MOD = 998244353;
// Base from which init() derives the root of unity G^((MOD - 1) / len).
const int G = 3;
// Modular inverse of 2 under MOD (2 * INV2 % MOD == 1).
const int INV2 = 499122177;

// Input sequences.
int a[N];
int b[N];
// Bit-reversal permutation table, filled by init().
int rev[M << 2];

// Computes a^b modulo MOD by binary exponentiation.
//
// a: base.
// b: exponent; expected to be non-negative (b == 0 yields 1).
// Returns the result reduced with % MOD.
int power(int a, int b) {
    long long base = a;
    long long result = 1;
    while (b) {
        // Fold the current square into the result when the low bit is set.
        if (b & 1) {
            result = result * base % MOD;
        }
        base = base * base % MOD;
        b >>= 1;
    }
    return static_cast<int>(result);
}

// Successive powers of the root of unity, filled by init().
int wn[M << 2];
// Scratch buffer holding a one-dimensional slice of f during a transform.
int A[M << 2];
// Two-dimensional table: first index is the value sum, second is the XOR.
int f[M << 2][M];

// Prepares the lookup tables used by fft() for a transform of at least n
// points: rounds n up to a power of two, fills rev[] with the bit-reversal
// permutation, and fills wn[] with the first half of the powers of
// G^((MOD - 1) / size).
void init(int n) {
    int size = 1;
    int bits = 0;
    while (size < n) {
        size <<= 1;
        ++bits;
    }

    // rev[i] is rev[i / 2] shifted right, with i's low bit moved to the top.
    for (int i = 1; i < size; ++i) {
        const int top = (i & 1) << (bits - 1);
        rev[i] = (rev[i >> 1] >> 1) | top;
    }

    const int root = power(G, (MOD - 1) / size);
    wn[0] = 1;
    wn[1] = root;
    for (int i = 2; i < (size >> 1); ++i) {
        wn[i] = 1ll * wn[i - 1] * root % MOD;
    }
}

// In-place number-theoretic transform of a[0..len) modulo MOD.
//
// a:   buffer of len values.
// len: transform length; must match the size most recently passed to init(),
//      since rev[] and wn[] are read as prepared for that size.
// inv: when true, performs the inverse transform (reverses a[1..len) and
//      scales every entry by len^(MOD - 2)).
void fft(int *a, int len, bool inv = 0) {
    // Reorder into bit-reversed order; swap each pair only once.
    for (int i = 0; i < len; ++i) {
        const int j = rev[i];
        if (i < j) {
            std::swap(a[i], a[j]);
        }
    }

    for (int half = 1; half < len; half <<= 1) {
        const int span = half << 1;
        const int stride = len / span;  // step through wn[] for this level
        for (int base = 0; base < len; base += span) {
            int *lo = a + base;
            int *hi = lo + half;
            for (int k = 0; k < half; ++k) {
                const int u = lo[k];
                const int v = 1ll * wn[stride * k] * hi[k] % MOD;
                lo[k] = (u + v) % MOD;
                hi[k] = (MOD + u - v) % MOD;
            }
        }
    }

    if (!inv) {
        return;
    }

    std::reverse(a + 1, a + len);
    const int scale = power(len, MOD - 2);
    for (int i = 0; i < len; ++i) {
        a[i] = 1ll * a[i] * scale % MOD;
    }
}

// In-place Walsh-Hadamard (XOR) transform of a[0..len) modulo MOD.
//
// a:   buffer of len values.
// len: transform length.
// inv: when true, every butterfly output is additionally multiplied by INV2,
//      which makes this the inverse transform.
void fwt(int *a, int len, bool inv = 0) {
    for (int half = 1; half < len; half <<= 1) {
        const int span = half << 1;
        for (int base = 0; base < len; base += span) {
            for (int k = 0; k < half; ++k) {
                int &lo = a[base + k];
                int &hi = a[base + half + k];
                int sum = (lo + hi) % MOD;
                int diff = (MOD + lo - hi) % MOD;
                if (inv) {
                    sum = 1ll * sum * INV2 % MOD;
                    diff = 1ll * diff * INV2 % MOD;
                }
                lo = sum;
                hi = diff;
            }
        }
    }
}

int main() {
int n;
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", &a[i]);
}
for (int i = 1; i <= n; ++i) {
scanf("%d", &b[i]);
f[a[i]][b[i]] = f[a[i]][b[i]] + 1;
}
init(M << 2);
for (int i = 0; i < M << 2; ++i) {
for (int j = 0; j < M; ++j) {
A[j] = f[i][j];
}
fwt(A, M);
for (int j = 0; j < M; ++j) {
f[i][j] = A[j];
}
}
for (int i = 0; i < M; ++i) {
for (int j = 0; j < M << 2; ++j) {
A[j] = f[j][i];
}
fft(A, M << 2);
for (int j = 0; j < M << 2; ++j) {
f[j][i] = A[j];
}
}
for (int i = 0; i < M << 2; ++i) {
for (int j = 0; j < M; ++j) {
f[i][j] = power(f[i][j], 4);
}
}
for (int i = 0; i < M << 2; ++i) {
for (int j = 0; j < M; ++j) {
A[j] = f[i][j];
}
fwt(A, M, 1);
for (int j = 0; j < M; ++j) {
f[i][j] = A[j];
}
}
for (int i = 0; i < M; ++i) {
for (int j = 0; j < M << 2; ++j) {
A[j] = f[j][i];
}
fft(A, M << 2, 1);
for (int j = 0; j < M << 2; ++j) {
f[j][i] = A[j];
}
}
int ans = 0;
for (int i = 0; i < M << 2; ++i) {
for (int j = 0; j < M; ++j) {
if (f[i][j]) {
ans = (ans + 1ll * f[i][j] * power(i, j)) % MOD;
}
}
}
printf("%d\n", ans);
return 0;
}

418 I'm a teapot