Four Loop

题目描述

Given two sequences $a$ and $b$ of equal length $n$, find the value of

$$\sum_{x=1}^n \sum_{y=1}^n \sum_{z=1}^n \sum_{w=1}^n (a_x+a_y+a_z+a_w)^{(b_x \oplus b_y \oplus b_z \oplus b_w)}$$

题意概述

给定两个长度为$n$的序列$a$和$b$,求$\sum_{x=1}^n \sum_{y=1}^n \sum_{z=1}^n \sum_{w=1}^n (a_x+a_y+a_z+a_w)^{(b_x \oplus b_y \oplus b_z \oplus b_w)}$。

数据范围:$1 \le n \le 10^5, \; 1 \le a_i \le 500, \; 1 \le b_i \le 500$。

算法分析

$a_x+a_y+a_z+a_w$的取值范围为$[4,2000]$，$b_x \oplus b_y \oplus b_z \oplus b_w$的取值范围为$[0,511]$（因为$b_i < 2^9$）。因此可以对每一对（和$s$，异或值$t$）统计满足条件的四元组个数$c_{s,t}$，答案即为$\sum_{s,t} c_{s,t} \cdot s^t$。

令$f_{1,i,j}$表示有多少个$x$满足$a_x=i \land b_x=j$。对于$k \ge 2$,令

$$f_{k,i,j}=\sum_{i_1+i_2=i} \sum_{j_1 \oplus j_2=j} f_{k-1,i_1,j_1}f_{1,i_2,j_2}$$

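$f_{4,i,j}$即上面所需的每种情况出现的次数。转移方程在第一维（和）上是加法卷积（即普通的多项式乘法），在第二维（异或值）上是异或卷积，可以分别用FFT（代码中在模$998244353$下以NTT实现）和FWT处理。由于两种变换都把对应的卷积变为逐点乘积，只需对$f_1$做一次二维正变换，将每个位置的值取4次方，再做一次二维逆变换即可得到$f_4$，无需逐步计算$f_2,f_3$。第一维的变换长度取$2048 > 2000$，保证循环卷积不会发生回绕。总时间复杂度为$O(n + a_{\max}b_{\max}\log(a_{\max}b_{\max}))$。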

代码

#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
 
// Problem-wide constants.
//   N    : capacity of the input arrays (indices 1..n are used).
//   M    : size of the XOR axis; every b_i, and every XOR of four b's, lies in [0, M).
//   MOD  : NTT-friendly prime used for all arithmetic and for the printed answer.
//   G    : primitive root of MOD, used to build the NTT roots of unity.
//   INV2 : modular inverse of 2, used by the inverse Walsh-Hadamard transform.
constexpr int N = 100005;
constexpr int M = 512;
constexpr int MOD = 998244353;
constexpr int G = 3;
constexpr int INV2 = (MOD + 1) / 2;

// Raw input sequences (1-based) and the bit-reversal permutation for the NTT.
int a[N];
int b[N];
int rev[M << 2];
 
// Modular exponentiation: returns a^b mod MOD by square-and-multiply.
// Assumes a >= 0 and b >= 0; power(x, 0) == 1 for every x, including 0.
int power(int a, int b) {
    long long base = a;
    long long result = 1;
    while (b != 0) {
        if (b & 1) {
            result = result * base % MOD;
        }
        base = base * base % MOD;
        b >>= 1;
    }
    return static_cast<int>(result);
}
 
// NTT twiddle table: wn[k] = w^k for the transform length prepared by init().
int wn[M << 2];
// Scratch buffer holding one strided column of f while it is transformed.
int A[M << 2];
// f[s][x]: initially the number of input indices i with a_i = s and b_i = x;
// main transforms it in place and it finally holds the 4-tuple counts (mod MOD).
int f[M << 2][M];
 
// Prepares the NTT tables for transforms of length len, the smallest power of
// two >= n:
//   rev[i] = i with its log2(len) low bits reversed,
//   wn[k]  = w^k for 0 <= k < len / 2, where w = G^((MOD - 1) / len).
// fft() must later be called with this same len.
// NOTE(review): assumes len <= M << 2 (table capacity) and that len divides MOD - 1.
void init(int n) {
    // Round n up to a power of two: len = 2^lg.
    int len = 1;
    int lg = 0;
    while (len < n) {
        len <<= 1;
        ++lg;
    }

    // Build the bit-reversal permutation from the already-computed rev[i >> 1].
    for (int i = 1; i < len; ++i) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    }

    // Successive powers of the len-th root of unity.
    const int root = power(G, (MOD - 1) / len);
    wn[0] = 1;
    wn[1] = root;
    for (int k = 2; k < len / 2; ++k) {
        wn[k] = static_cast<int>(1LL * wn[k - 1] * root % MOD);
    }
}
 
// In-place iterative number-theoretic transform (NTT) modulo MOD.
//   a   : len residues in [0, MOD), overwritten with the result.
//   len : transform length; must equal the length most recently passed to
//         init(), which filled rev[] and wn[] for exactly this size.
//   inv : false for the forward transform; true for the inverse transform,
//         which also divides by len so that inverse(forward(x)) == x.
void fft(int *a, int len, bool inv = false) {
    // Reorder into bit-reversed index order so the butterflies run in place.
    for (int i = 0; i < len; ++i) {
        if (i < rev[i]) {
            std::swap(a[i], a[rev[i]]);
        }
    }
    // Merge blocks of size `half` into blocks of size 2 * half.
    for (int half = 1; half < len; half <<= 1) {
        // wn[] holds powers of the len-th root; stepping by len / (2 * half)
        // selects the powers of the (2 * half)-th root needed at this level.
        const int stride = len / (half << 1);
        for (int j = 0; j < len; j += half << 1) {
            for (int k = 0; k < half; ++k) {
                int x = a[j + k];
                int y = 1ll * wn[stride * k] * a[j + half + k] % MOD;
                a[j + k] = (x + y) % MOD;
                a[j + half + k] = (MOD + x - y) % MOD;
            }
        }
    }
    if (inv) {
        // Evaluating at w^{-k} = w^{len-k} equals the forward transform
        // followed by reversing entries 1..len-1.
        std::reverse(a + 1, a + len);
        // Fermat inverse of len (MOD is prime). Named inv_len so it does not
        // shadow the `inv` parameter.
        const int inv_len = power(len, MOD - 2);
        for (int i = 0; i < len; ++i) {
            a[i] = 1ll * a[i] * inv_len % MOD;
        }
    }
}
 
// In-place Walsh-Hadamard transform modulo MOD: the transform that turns XOR
// convolution into pointwise multiplication.
//   a   : len residues in [0, MOD), overwritten with the result.
//   len : transform length; the butterfly indexing assumes a power of two.
//   inv : when true, computes the inverse transform by halving (multiplying
//         by INV2) after every butterfly, i.e. dividing by len overall.
void fwt(int *a, int len, bool inv = 0) {
    for (int half = 1; half < len; half <<= 1) {
        for (int start = 0; start < len; start += 2 * half) {
            int *lo = a + start;
            int *hi = lo + half;
            for (int k = 0; k < half; ++k) {
                const int u = lo[k];
                const int v = hi[k];
                long long sum = (u + v) % MOD;
                long long diff = (MOD + u - v) % MOD;
                if (inv) {
                    sum = sum * INV2 % MOD;
                    diff = diff * INV2 % MOD;
                }
                lo[k] = static_cast<int>(sum);
                hi[k] = static_cast<int>(diff);
            }
        }
    }
}
 
// Applies the Walsh-Hadamard transform (forward or inverse) along the XOR
// axis, i.e. to every row f[s][0..M-1]. Rows are contiguous, so fwt works
// on them in place.
static void transform_xor_axis(bool inverse) {
    for (int s = 0; s < M << 2; ++s) {
        fwt(f[s], M, inverse);
    }
}

// Applies the NTT (forward or inverse) along the sum axis, i.e. to every
// column f[0..4M-1][x]. Columns are strided, so each one is gathered into the
// scratch buffer A, transformed, and scattered back.
// Requires init(M << 2) to have been called.
static void transform_sum_axis(bool inverse) {
    for (int x = 0; x < M; ++x) {
        for (int s = 0; s < M << 2; ++s) {
            A[s] = f[s][x];
        }
        fft(A, M << 2, inverse);
        for (int s = 0; s < M << 2; ++s) {
            f[s][x] = A[s];
        }
    }
}

// Reads n, a[1..n], b[1..n] and prints, modulo MOD,
//   sum over all (x, y, z, w) in [1, n]^4 of
//   (a_x + a_y + a_z + a_w) ^ (b_x xor b_y xor b_z xor b_w).
//
// f[s][x] starts as the number of indices i with a_i = s and b_i = x. The
// number of 4-tuples with a-sum s and b-xor x is the 4-fold self-convolution
// of f (ordinary convolution in s, XOR convolution in x). It is computed by
// transforming both axes, raising every entry to the 4th power, and
// transforming back. The NTT length 4M = 2048 exceeds the largest possible
// sum 4 * (M - 1), so the cyclic convolution never wraps around.
// Returns 1 without printing anything on malformed or out-of-range input.
int main() {
    // a[] is indexed 1..n, so n must stay below its capacity.
    const int capacity = static_cast<int>(sizeof(a) / sizeof(a[0]));
    int n;
    if (scanf("%d", &n) != 1 || n < 0 || n >= capacity) {
        return 1;
    }
    for (int i = 1; i <= n; ++i) {
        // a_i < M keeps every sum of four values below the NTT length 4M.
        if (scanf("%d", &a[i]) != 1 || a[i] < 0 || a[i] >= M) {
            return 1;
        }
    }
    for (int i = 1; i <= n; ++i) {
        // b_i < M keeps every XOR of four values inside the row width M.
        if (scanf("%d", &b[i]) != 1 || b[i] < 0 || b[i] >= M) {
            return 1;
        }
        ++f[a[i]][b[i]];
    }

    init(M << 2);
    transform_xor_axis(false);
    transform_sum_axis(false);
    // Pointwise 4th power in the transformed domain = 4-fold convolution.
    for (int s = 0; s < M << 2; ++s) {
        for (int x = 0; x < M; ++x) {
            const long long sq = 1ll * f[s][x] * f[s][x] % MOD;
            f[s][x] = static_cast<int>(sq * sq % MOD);
        }
    }
    transform_xor_axis(true);
    transform_sum_axis(true);

    // f[s][x] now counts 4-tuples with sum s and xor x (mod MOD). Weight each
    // count by s^x, building the powers incrementally along the row.
    int ans = 0;
    for (int s = 0; s < M << 2; ++s) {
        long long pw = 1;  // s^x mod MOD
        for (int x = 0; x < M; ++x) {
            ans = static_cast<int>((ans + f[s][x] * pw) % MOD);
            pw = pw * s % MOD;
        }
    }
    printf("%d\n", ans);
    return 0;
}

RegMs If

418 I'm a teapot

Leave a Reply

Your email address will not be published. Required fields are marked *