首页 > 技术文章 > 「hackerrank」Range Modular Queries

luckyblock 2020-09-07 21:48 原文

知识点: 分块,莫队

原题面:hackerrank


leaderdash 里一堆五星红旗老哥直接用 vector 水过去了= =


题意简述

给定一长度为 \(n\) 的数列 \(a\)\(q\) 次询问。
每次询问给定参数 \(l,r,x,y\),求:

\[\sum_{i=l}^{r} [a_i \bmod x = y] \]

\(1\le n,q,a_i\le 5\times 10^4\)


分析题意

数据范围比较喜人,\(n,q,a_i\) 同级,先考虑暴力。
考虑莫队搞掉区间限制,维护当前区间内各权值出现的次数。
\(cnt_{i}\) 表示当前区间内权值 \(i\) 出现的次数,显然答案为:

\[\sum_{k = 1}cnt_{y+kx} \]

复杂度 \(O(\text{Unknowen})\),过不了。


发现上述的算法查询复杂度与 \(x\) 有关。
\(x\) 较小时,查询复杂度较高,可达到 \(O(n)\) 级别。\(x\) 较大时复杂度较优秀。
考虑根号分治,对模数 \(x\le 200\)\(x > 200\) 的询问分开考虑。

对于 \(x\le 200\) 的询问,考虑分块。
设块大小为 \(\sqrt{n}\),预处理 \(f_{i,j,k}\) 表示前 i 个块中,% j = k 的数的个数。
预处理复杂度上界 \(O(200^3) \approx O(n\sqrt{n})\)
询问时整块直接 \(O(1)\) 查询预处理的前缀和,散块暴力。
单次查询复杂度 \(O(\sqrt{n})\)

对于 \(x> 200\) 的询问,套用上述莫队算法即可。
单次查询复杂度上界为 \(O(200) \approx O(\sqrt{n})\) 级别。

\(n,q\) 同级,总复杂度约为 \(O(n\sqrt{n})\) 级别,可过。


代码实现

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#define ll long long
const int kMaxn = 4e4 + 10;
const int kMaxSqrtn = 210;
//=============================================================
struct Query {
  int l, r, mod, val, id;
} q[kMaxn];
int n, m, qnum, maxa, a[kMaxn], ans[kMaxn];
int block_size, block_num, sqrt_maxa, L[kMaxSqrtn], R[kMaxSqrtn], bel[kMaxn];
int f[kMaxSqrtn][kMaxSqrtn][kMaxSqrtn]; //f[i][j][k]:前 i 个块,% j = k 的数的个数。 
int nowl = 1, nowr, cnt[kMaxn]; //注意端点的初值 
//=============================================================
inline int read() {
  int f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Getmax(int &fir_, int sec_) {
  if (sec_ > fir_) fir_ = sec_;
}
void Getmin(int &fir_, int sec_) {
  if (sec_ < fir_) fir_ = sec_;
}
void Debug() {
  printf("%d %d\n***********\n", block_size, block_num);
  for (int i = 1; i <= block_num; ++ i) {
    for (int j = 1; j <= block_size; ++ j) {
      for (int k = 0; k < j; ++ k) {
        printf("1~%d mod %d = %d appear %d times\n", i, j, k, f[i][j][k]);
      }
    }
  }
}
void PrepareBlock() {
  block_size = (int) sqrt(n);
  block_num = n / block_size;
  for (int i = 1; i <= block_num; ++ i) {
    L[i] = (i - 1) * block_size + 1;
    R[i] = i * block_size;
  }
  if (R[block_num] < n) {
    ++ block_num;
    L[block_num] = R[block_num - 1] + 1;
    R[block_num] = n;
  }
  for (int i = 1; i <= block_num; ++ i) {
    for (int j = L[i]; j <= R[i]; ++ j) {
      bel[j] = i;
    }
  }
  
  sqrt_maxa = (int) sqrt(maxa);
  for (int i = 1; i <= block_num; ++ i) {
    for (int j = 1; j <= sqrt_maxa; ++ j) {
      for (int k = L[i]; k <= R[i]; ++ k) {
        f[i][j][a[k] % j] ++;
      }
    }  
  }
  for (int i = 1; i <= block_num; ++ i) {
    for (int j = 1; j <= sqrt_maxa; ++ j) {
      for (int k = 0; k < j; ++ k) {
        f[i][j][k] += f[i - 1][j][k];
      }
    }
  }
//  Debug();
}
bool CompareQuery(Query fir_, Query sec_) {
  if (bel[fir_.l] != bel[sec_.l]) return bel[fir_.l] < bel[sec_.l];
  return fir_.r < sec_.r;
}
void Solve(int l_, int r_, int mod_, int val_, int id_) {
  if (bel[l_] == bel[r_]) {
    for (int i = l_; i <= r_; ++ i) {
      ans[id_] += (a[i] % mod_ == val_);
    }
    return ;
  }
  int bell = bel[l_], belr = bel[r_];
  ans[id_] += f[belr - 1][mod_][val_] - f[bell][mod_][val_];
  for (int i = l_; i <= R[bell]; ++ i) {
    ans[id_] += (a[i] % mod_ == val_);
  }
  for (int i = L[belr]; i <= r_; ++ i) {
    ans[id_] += (a[i] % mod_ == val_);
  }
}
void Add(int pos_) {
  cnt[a[pos_]] ++;
}
void Del(int pos_) {
  cnt[a[pos_]] --;
}
//=============================================================
int main() {
  n = read(), m = read();
  for (int i = 1; i <= n; ++ i) {
    a[i] = read();
    Getmax(maxa, a[i]); 
  }
  PrepareBlock();
  
  for (int i = 1; i <= m; ++ i) {
    int l = read() + 1, r = read() + 1;
    int mod = read(), val = read();
    if (mod <= sqrt_maxa) {
      Solve(l, r, mod, val, i);
    } else {
      q[++ qnum] = (Query) {l, r, mod, val, i};  
    }
  }
  
  std :: sort(q + 1, q + qnum + 1, CompareQuery);
  for (int i = 1; i <= qnum; ++ i) {
    int l = q[i].l, r = q[i].r, mod = q[i].mod, val = q[i].val;
    while (nowl > l) -- nowl, Add(nowl);
    while (nowr < r) ++ nowr, Add(nowr);
    while (nowl < l) Del(nowl), ++ nowl;
    while (nowr > r) Del(nowr), -- nowr;
    for (int j = val; j <= maxa; j += mod) {
      ans[q[i].id] += cnt[j];
    }
  }
  for (int i = 1; i <= m; ++ i) printf("%d\n", ans[i]);
  return 0;
}
/*
5 3
250 501 5000 5 4
0 4 5 0
0 4 10 0
0 4 3 2
*/

推荐阅读