CodeChef COUNTARI 分块 + FFT

CodeChef COUNTARI 分块 + FFT,第1张

题意

传送门 CodeChef - COUNTARI Arithmetic Progressions

题解

求满足 2 a j = a i + a k , i < j < k 2a_j = a_i + a_k,i2aj=ai+ak,i<j<k 的三元组数量。


直接卷积后枚举 j j j,难以处理非法三元组的贡献;枚举 j j j 对左右两侧进行卷积,时间复杂度过高。


考虑分块,块大小为 c c c


对于处于三个不同块的三元组, 枚举 j j j 所在块,对两侧块卷积,时间复杂度 O ( n 2 / c log ⁡ n ) O(n^2/c\log n) O(n2/clogn)


对于至少存在两个元素位于同一个块的三元组,暴力统计,时间复杂度 O ( n c ) O(nc) O(nc)


块大小取 O ( n log ⁡ n ) O(n\log n) O(nlogn),总时间复杂度 O ( n n log ⁡ n ) O(n\sqrt{n\log n}) O(nnlogn )


#include 
using namespace std;
using cp = complex<double>;
constexpr double PI = acos(-1.0);
vector<int> rev;
struct Poly : vector<cp>
{
    Poly() {}
    Poly(int n) : vector<cp>(n) {}
    Poly(const initializer_list<cp> &list) : vector<cp>(list) {}
    void fft(int n, bool inverse)
    {
        if ((int)rev.size() != n)
        {
            rev.resize(n);
            for (int i = 0; i < n; ++i)
                rev[i] = rev[i >> 1] >> 1 | (i & 1 ? n >> 1 : 0);
        }
        resize(n);
        for (int i = 0; i < n; ++i)
            if (i < rev[i])
                std::swap(at(i), at(rev[i]));

        for (int m = 1; m < n; m <<= 1)
        {
            int m2 = m << 1;
            for (int i = 0; i < n; i += m2)
            {
                cp w = cp(1, 0), _w = cp(cos(2 * PI / m2), sin(2 * PI / m2));
                if (inverse)
                    _w = conj(_w);
                for (int j = 0; j < m; ++j, w *= _w)
                {
                    cp &x = at(i + j), &y = at(i + j + m);
                    cp t = w * y;
                    y = x - t;
                    x += t;
                }
            }
        }
    }
    void dft(int n) { fft(n, 0); };
    void idft(int n)
    {
        fft(n, 1);
        for (int i = 0; i < n; ++i)
            at(i) /= n;
    }
    Poly operator*(const Poly &p) const
    {
        auto a = *this, b = p;
        int k = 1, n = a.size() + b.size() - 1;
        while (k < n)
            k <<= 1;
        a.dft(k), b.dft(k);
        for (int i = 0; i < k; ++i)
            a[i] *= b[i];
        a.idft(k);
        a.resize(n);
        return a;
    }
};
using ll = long long;
constexpr int MAXN = 1E5 + 5, MAXA = 3E4 + 5;
int N, A[MAXN];
int L[MAXN], R[MAXN];
int pre[MAXN * 2], nxt[MAXN * 2], tmp[MAXN * 2];

int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> N;
    for (int i = 0; i < N; ++i)
        cin >> A[i];
    int sz = 2 * ceil(sqrt(N * log2(N)));
    int n = 0;
    for (int i = 0; i < N; i += sz)
    {
        L[n] = i, R[n] = min(N, i + sz);
        ++n;
    }
    ll res = 0;
    for (int i = 0; i < N; ++i)
        ++nxt[A[i]];
    for (int i = 0; i < n; ++i)
    {
        for (int j = L[i]; j < R[i]; ++j)
            --nxt[A[j]];
        if (0 < i && i + 1 < n)
        {
            Poly f(MAXA), g(MAXA);
            for (int j = 0; j < MAXA; ++j)
                f[j].real(pre[j]), g[j].real(nxt[j]);
            f = f * g;
            for (int j = L[i]; j < R[i]; ++j)
                res += floor(f[2 * A[j]].real() + 0.5);
        }
        for (int j = L[i]; j < R[i]; ++j)
        {
            for (int k = j + 1; k < R[i]; ++k)
            {
                int a = A[k] * 2 - A[j];
                if (a >= 0)
                    res += nxt[a];
                a = A[j] * 2 - A[k];
                if (a >= 0)
                    res += pre[a] + tmp[a];
            }
            ++tmp[A[j]];
        }
        for (int j = L[i]; j < R[i]; ++j)
            ++pre[A[j]], --tmp[A[j]];
    }

    cout << res << '\n';
    return 0;
}

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/569783.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-09
下一篇 2022-04-09

发表评论

登录后才能评论

评论列表(0条)

保存