首页 > 技术文章 > P2839 [国家集训队]middle

luckyblock 2020-10-09 11:27 原文

知识点:二分答案,可持久化线段树

原题面: Luogu

用到了和这题同样的套路:「TJOI / HEOI2016」排序

简述

对于一长度为 \(n\) 的数列 \(a\),定义其「中位数」为排序后数列的第 \(\frac{n}{2}\) 个数,下标从 0 开始
给定一长度为 \(n\) 的数列 \(a\)\(q\) 次询问,每次询问给定参数 \(a,b,c,d\),保证 \(a<b<c<d\)。求区间左端点在 \([a,b]\) 之间,右端点在 \([c,d]\) 之间的子区间中,最大的「中位数」,「中位数」的定义见上文。
强制在线
\(1\le n\le 2\times 10^4\)\(1\le q\le 2.5\times 10^4\)
2S,512MB。

分析题意

国际惯例先离散化,以下均在离散化的前提下展开。
对于一个询问 \((a,b,c,d)\),可拆成三部分考虑:\([b,c]\) 内必选的部分和 \([a,b)\) 的一段后缀,\((c,d]\) 的一段前缀。仅需考虑什么样的前后缀 对答案有贡献即可。

假设已选出一段区间 \([l,r]\),有 \(a<l\le b\)\(c\le r<d\),设其中位数为 \(mid\)。显然,此时从左右两侧向区间内添加不小于 \(mid\) 的数可能会使中位数增大,添加小于 \(mid\) 的数可能会使中位数减小。
手玩一下可发现,当且仅当 添加的不小于 \(mid\) 的数的个数 大于等于 添加的小于 \(mid\) 的个数时,中位数才可能增大。不小于 \(mid\) 的数的个数 比小于 \(mid\) 的数的个数越多,贡献就越大。

下文将一个区间中不小于 \(mid\) 的个数与小于 \(mid\) 的个数 的差值称为该区间的贡献。


之后怎么做?动态维护中位数,找到左右两侧对 当前中位数 贡献最大的前后缀加入?发现不好搞,因为中位数变化后,贡献也在变化。

考虑枚举固定一个中位数的 下界,从而固定前后缀的贡献。以此下界为据,加入左右两侧贡献最大的前后缀。判断最后的中位数是否满足大于该下界。
若满足,说明答案 \(\ge\) 该下界,可调高下界,否则需要调低下界。发现该 中位数的下界 满足单调性,可二分答案枚举 \(mid\)

发现选择 1 个不小于 \(mid\) 的数 和 1 个小于 \(mid\) 的数贡献会相互抵消。更形象地,1 个不小于 \(mid\) 的数的 贡献 为 1,选择 1 个小于 \(mid\) 的数的 代价 也为 1。由此,枚举下界后,考虑条件转换:将区间 \([a,d]\) 内不小于 \(mid\) 的数变为 1,小于 \(mid\) 的数变为 -1。此时一段区间的贡献变为该区间转化后的和。

单次 Check 需要解决的问题变为:判断转化后的数列,\([b,c]\) 的和 + \([a,b)\) 的最大后缀和 + \((c,d]\) 的最大前缀和,是否不小于 \(0\)。暴力实现单次 Check 的复杂度为 \(O(n)\),总复杂度为 \(O(qn\log n)\),过不了。


地球人都知道最大前后缀和可以用线段树维护:SP1043 GSS1 - Can you answer these queries I

考虑预处理每一个二分的下界对应的数列,并建出线段树。这样单次 Check 的复杂度变为 \(O(\log n)\),询问的复杂度变为 \(O(q\log^2 n)\)
但预处理 \(n\) 棵线段树空间,时间都是 \(O(n^2 \log n)\) 的,还是很菜,过不了。

下面是一个很神的优化 (不愧是clj)
考虑二分答案枚举的下界 \(mid\),发现对于下界为 \(mid\)\(mid+1\) 时的数列,仅有 \(a_i = mid\) 的位置不同,由 1 变为了 -1。再观察上面建的 \(n\) 棵线段树,由上可知,对于第 \(mid\) 棵与第 \(mid+1\) 棵,仅有\(a_i = mid\) 的位置不同。
考虑对 二分值 进行可持久化,从而压缩线段树空间。 离散化后二分值最多仅有 \(n\) 种取值,预处理的空间时间均压缩为 \(O(n\log n)\) 级别。再套用上面的查询,总复杂度为 \(O(n\log n + q\log^2 n)\),可过。

代码

注意可持久化线段树的实际含义。
root 的下标是 二分值,因此构建 root[i] 时,修改对象为 \(a_j<i\) 的位置。

100pts

//知识点:二分答案,可持久化线段树 
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <vector>
#define ll long long
const int kMaxn = 2e4 + 10;
//=============================================================
struct Node {
  int sum, lmax, rmax;
};
int n, q, ans, a[kMaxn], root[kMaxn];
int data_num, data[kMaxn], map[kMaxn];
//=============================================================
inline int read() {
  int f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir_, int sec_) {
  if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
  if (sec_ < fir_) fir_ = sec_;
}
namespace Seg {
  #define ls (lson[now_])
  #define rs (rson[now_])
  #define mid ((L_+R_)>>1)
  int node_num, lson[kMaxn << 5], rson[kMaxn << 5];
  int sum[kMaxn << 5], lmax[kMaxn << 5], rmax[kMaxn << 5];
  void Debug(int now_, int L_, int R_) {
    printf("%d %d:%d %d %d\n", L_, R_, sum[now_], lmax[now_], rmax[now_]);
    if (L_ == R_) {
      return ;
    }
    Debug(ls, L_, mid);
    Debug(rs, mid + 1, R_);
  }
  void Pushup(int now_) {
    sum[now_] = sum[ls] + sum[rs];
    lmax[now_] = std::max(lmax[ls], sum[ls] + lmax[rs]);
    rmax[now_] = std::max(rmax[ls] + sum[rs], rmax[rs]);
  }
  void Build(int &now_, int L_, int R_) {
    now_ = ++ node_num;
    if (L_ == R_) {
      lmax[now_] = rmax[now_] = sum[now_] = 1;
      return ;
    }
    Build(ls, L_, mid);
    Build(rs, mid + 1, R_);
    Pushup(now_);
  }
  void Modify(int &now_, int pre_, int L_, int R_, int pos_) {
    now_ = ++ node_num;
    ls = lson[pre_], rs = rson[pre_];
    if (L_ == R_) {
      lmax[now_] = rmax[now_] = 0;
      sum[now_] = -1;
      return ;
    }
    if (pos_ <= mid) Modify(ls, lson[pre_], L_, mid, pos_);
    else Modify(rs, rson[pre_], mid + 1, R_, pos_);
    Pushup(now_);
  }
  Node Merge(Node l_, Node r_) {
    Node ret = (Node) {0, 0, 0};
    ret.sum = l_.sum + r_.sum;
    ret.lmax = std::max(l_.lmax, l_.sum + r_.lmax);
    ret.rmax = std::max(l_.rmax + r_.sum, r_.rmax);
    return ret;
  }
  Node Query(int now_, int L_, int R_, int l_, int r_) {
    if (l_ > r_) return (Node) {0, 0, 0};
    if (l_ <= L_ && R_ <= r_) {
      return (Node) {sum[now_], lmax[now_], rmax[now_]};
    }
    if (r_ <= mid) return Query(ls, L_, mid, l_, r_); 
    if (l_ > mid) return Query(rs, mid + 1, R_, l_, r_);
    return Merge(Query(ls, L_, mid, l_, mid), Query(rs, mid + 1, R_, mid + 1, r_));
  }
  #undef ls
  #undef rs
  #undef mid 
}
void Prepare() {
  n = read();
  for (int i = 1; i <= n; ++ i) a[i] = data[i] = read();
  
  std::sort(data + 1, data + n + 1);
  data_num = 1;
  map[1] = data[1];
  for (int i = 2; i <= n; ++ i) {
    if (data[i] != data[i - 1]) {
      map[++ data_num] = data[i];
    }
    data[data_num] = data[i]; 
  }
  
  std :: vector <int> pos[kMaxn];
  for (int i = 1; i <= n; ++ i) {
    a[i] = std::lower_bound(data + 1, data + data_num + 1, a[i]) - data;
    pos[a[i]].push_back(i);
  }
  
  Seg::Build(root[1], 1, n);
  for (int i = 2; i <= data_num; ++ i) {
    root[i] = root[i - 1];
    for (int j = 0, lim = pos[i - 1].size(); j < lim; ++ j) {
      Seg::Modify(root[i], root[i], 1, n, pos[i - 1][j]);
    }
  }
}
bool Check(int ll_, int lr_, int rl_, int rr_, int lim_) {
  int sum = Seg::Query(root[lim_], 1, n, lr_, rl_).sum;
  int suml = Seg::Query(root[lim_], 1, n, ll_, lr_ - 1).rmax;
  int sumr = Seg::Query(root[lim_], 1, n, rl_ + 1, rr_).lmax;
  return (sum + suml + sumr) >= 0;
}
void Query(int ll_, int lr_, int rl_, int rr_) {
  for (int l = 1, r = data_num; l <= r; ) {
    int mid = ((l + r) >> 1);
    if (Check(ll_, lr_, rl_, rr_, mid)) {
      ans = map[mid];
      l = mid + 1;
    } else {
      r = mid - 1;
    }
  }
}
//=============================================================
int main() {
  Prepare();
  q = read();
  for (int i = 1; i <= q; ++ i) {
    int opt[4];
    for (int j = 0; j < 4; ++ j) {
      opt[j] = (read() + ans) % n;
    }
    std::sort(opt, opt + 4);
    Query(opt[0] + 1, opt[1] + 1, opt[2] + 1, opt[3] + 1);
    printf("%d\n", ans);
  }
  return 0;
}

30pts

//知识点:二分答案
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#define ll long long
const int kMaxn = 2e4 + 10;
//=============================================================
int n, q, ans, a[kMaxn], tmp[kMaxn];
int data_num, data[kMaxn], map[kMaxn];
//=============================================================
inline int read() {
  int f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir_, int sec_) {
  if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
  if (sec_ < fir_) fir_ = sec_;
}
void Prepare() {
  n = read();
  for (int i = 1; i <= n; ++ i) a[i] = data[i] = read();
  
  std :: sort(data + 1, data + n + 1);
  data_num = 1;
  map[1] = data[1];
  for (int i = 2; i <= n; ++ i) {
    if (data[i] != data[i - 1]) {
      map[++ data_num] = data[i];
    }
    data[data_num] = data[i]; 
  }
  for (int i = 1; i <= n; ++ i) {
    a[i] = std :: lower_bound(data + 1, data + data_num + 1, a[i]) - data;
  }
}
bool Check(int ll_, int lr_, int rl_, int rr_, int lim_) {
  int sum = 0, suml = 0, sumr = 0;
  for (int i = lr_; i <= rl_; ++ i) {
    sum += (a[i] >= lim_ ? 1 : - 1); 
  }
  for (int i = lr_ - 1, sumnow = 0; i >= ll_; -- i) {
    sumnow += (a[i] >= lim_ ? 1 : - 1);
    Chkmax(suml, sumnow);
  }
  for (int i = rl_ + 1, sumnow = 0; i <= rr_; ++ i) {
    sumnow += (a[i] >= lim_ ? 1 : - 1);
    Chkmax(sumr, sumnow);
  }
  return (sum + suml + sumr) >= 0;
}
void Query(int ll_, int lr_, int rl_, int rr_) {
  for (int l = 1, r = data_num; l <= r; ) {
    int mid = ((l + r) >> 1);
    if (Check(ll_, lr_, rl_, rr_, mid)) {
      ans = map[mid];
      l = mid + 1;
    } else {
      r = mid - 1;
    }
  }
}
//=============================================================
int main() {
  Prepare();
  q = read();
  for (int i = 1; i <= q; ++ i) {
    int opt[4];
    for (int j = 0; j < 4; ++ j) {
      opt[j] = (read() + ans) % n;
    }
    std :: sort(opt, opt + 4);
    Query(opt[0] + 1, opt[1] + 1, opt[2] + 1, opt[3] + 1);
    printf("%d\n", ans);
  }
  return 0;
}

推荐阅读