首页 > 技术文章 > P4370 [Code+#4]组合数问题2

luckyblock 2020-10-17 21:43 原文

知识点:堆,小技巧

原题面:Luogu

简述

定义两个组合数 \(a_i\choose b_i\)\(a_j\choose b_j\) 不同,当且仅当 \(a_i\not = a_j\)\(b_i\not ={b_j}\)

给定参数 \(n,k\),要求选出 \(k\) 个不同的组合数 \(a_i\choose b_i\),满足 \(0\le b_i\le a_i\le n\),最大化它们的和。
输出它们的和 \(\bmod 10^9+ 7\) 的值。
\(1\le n\le 10^6\)\(1\le k\le 10^5\)

分析

显然答案即为前 \(k\) 大的组合数的和。
显然最大的组合数为 \(n\choose \frac{n}{2}\),它一定需要被选择。

考虑次大的组合数的位置,显然只可能出现在以下四个位置:
\(n\choose \frac{n}{2}-1\)\(n\choose \frac{n}{2}+1\)\(n-1\choose \frac{n}{2}-1\)\(n-1\choose \frac{n}{2}\)

然后可以发现更一般的规律,比组合数 \(a\choose b\) 小的最大的数只能出现在下列四个位置:
\(a\choose b-1\)\(a\choose b+1\)\(a-1\choose b-1\)\(a-1\choose b\)

考虑使用元素降序的优先队列进行维护,初始时队列中仅有 \(n\choose \frac{n}{2}\)
每次取出队首元素,将其加入答案,枚举四个次小位置加入队列,注意去重
取出 \(k\) 个元素后即得答案。

还有个问题,组合数会很大,如何定义优先级比较方式。
考虑将组合数拆成下降幂,发现下面三个命题,互为充要条件:

\[\begin{aligned} \prod_{i=a+1}^{b}i &> \prod_{i=c+1}^{d} i\\ \log \prod_{i=a+1}^{b} i &> \log \prod_{i=c+1}^{d} i\\ \sum_{i=a+1}^{b}\log i &> \sum_{i=c+1}^{d}\log i \end{aligned}\]

预处理前缀 \(\log\) 值的和(即阶乘的 \(\log\) 值)即可比较两个下降幂的大小。
注意 priority_queue 奇怪的重载优先级。

代码

//知识点:堆,小技巧 
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <queue>
#include <set>
#define LL long long
#define pr std::pair
#define mp std::make_pair
const LL kMod = 1e9 + 7;
const int kMaxn = 1e6 + 10;
const int en[5] = {0, 0, -1, -1};
const int em[5] = {-1, 1, 0, -1};
//=============================================================
struct Data {
  int n, m;
  double val; 
  //注意奇怪的重载 
  bool operator < (const Data &sec) const {
    return val < sec.val;
  }
};
int n, k;
LL ans, fac[kMaxn];
double LogFac[kMaxn];
std::priority_queue <Data> q;
std::set <pr <int, int> > In_queue, Hash;
//=============================================================
inline int read() {
  int f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir_, int sec_) {
  if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
  if (sec_ < fir_) fir_ = sec_;
}
LL QuickPow(LL x_, LL y_, LL mod_) {
  LL ret = 1;
  for (; y_; y_ >>= 1) {
    if (y_ & 1) ret = ret * x_ % mod_;
    x_ = x_ * x_ % mod_;
  }
  return ret;
}
double LogC(int n_, int m_) {
  return LogFac[n_] - LogFac[m_] - LogFac[n_ - m_];
}
LL C(int n_, int m_) {
  if (n_ == m_ || m_ == 0) return 1ll;
  return fac[n_] * 
         QuickPow(fac[m_], kMod - 2, kMod) % kMod * 
         QuickPow(fac[n_ - m_], kMod - 2, kMod) % kMod;
}
//=============================================================
int main() {
  n = read(), k = read();
  fac[0] = 1;
  for (int i = 1; i <= n; ++ i) {
    fac[i] = 1ll * fac[i - 1] * i % kMod;
    LogFac[i] = LogFac[i - 1] + log(1.0 * i);
  }
  q.push((Data) {n, n / 2, LogC(n, n / 2)});
  In_queue.insert(mp(n, n / 2));
  for (int i = 1; i <= k; ++ i) {
    if (q.empty()) break; 
    Data top = q.top(); 
    q.pop();
    Hash.insert(mp(top.n, top.m));
    In_queue.erase(mp(top.n, top.m));
    ans = (ans + C(top.n, top.m)) % kMod;
    
    for (int j = 0; j < 4; ++ j) {
      int newn = top.n + en[j];
      int newm = top.m + em[j];
      if (newn < 0 || newm < 0 || newn < newm) continue ;
      if (Hash.count(mp(newn, newm))) continue ;
      if (In_queue.count(mp(newn, newm))) continue ;
      In_queue.insert(mp(newn, newm));
      q.push((Data) {newn, newm, LogC(newn, newm)});
      Data now = q.top();
    }
  }
  printf("%lld\n", ans);
  return 0;
}

/*
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <queue>
#define LL long long
const LL kMod = 1e9 + 7;
const int kMaxn = 1e6 + 10;
//=============================================================
struct Data {
  int n, m;
  double val;
  bool operator < (const Data &sec) const {
    return val < sec.val;
  }
};
int n, k;
LL ans, fac[kMaxn];
double LogFac[kMaxn];
std::priority_queue <Data> q;
//=============================================================
inline int read() {
  int f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir_, int sec_) {
  if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
  if (sec_ < fir_) fir_ = sec_;
}
LL QuickPow(LL x_, LL y_, LL mod_) {
  LL ret = 1;
  for (; y_; y_ >>= 1) {
    if (y_ & 1) ret = ret * x_ % mod_;
    x_ = x_ * x_ % mod_;
  }
  return ret;
}
double LogC(int n_, int m_) {
  return LogFac[n_] - LogFac[m_] - LogFac[n_ - m_];
}
LL C(int n_, int m_) {
  return fac[n_] * 
         QuickPow(fac[m_], kMod - 2, kMod) % kMod * 
         QuickPow(fac[n_ - m_], kMod - 2, kMod) % kMod;
}
//=============================================================
int main() {
  n = read(), k = read();
  fac[0] = 1;
  for (int i = 1; i <= n; ++ i) {
    fac[i] = 1ll * fac[i - 1] * i % kMod;
    LogFac[i] = LogFac[i - 1] + log2(1.0 * i);
  }
  for (int i = 0; i <= n; ++ i) {
    q.push((Data) {n, i, LogC(n, i)});
  }
  for (int i = 1; i <= k; ++ i) {
    Data top = q.top(); q.pop();
    ans = (ans + C(top.n, top.m)) % kMod;
    q.push((Data) {top.n - 1, top.m, LogC(top.n - 1, top.m)});
  }
  printf("%lld\n", ans);
  return 0;
}
*/ 

推荐阅读