首页 > 技术文章 > wqs 二分学习笔记

Point-King 2021-05-08 15:02 原文

正好模拟赛有一道关于这个的题目,我又恰好不会,然后就学学看。

发现一句比较重要的话:

然后如果这题没有刚好选 \(m\) 个的限制的时候就可以 \(\text{dp}\) 降维的话,那么就可以考虑一下 \(\text{WQS}\) 二分。

来考虑一道我很早以前就发现的 \(\text{wqs}\) 二分的题(但是我一直没做)。

例一:P2619 [国家集训队]Tree I

给你一个无向带权连通图,每条边是黑色或白色。让你求一棵最小权的恰好有 \(need\) 条白色边的生成树。

我们考虑没有白色边的限制,就必然是直接跑一个最小生成树就可以了,关键是如何二分这个凸函数的斜率和进行计算。

我来思考一下。

我们首先定义每一个物品的价值为 \(v_i\) ,同时定义凸函数 \(f(x)\) 表示在选择 \(x\)\(v_i\) 的情况下最小的权值。他是下凸的。

然后我们考虑二分这个斜率 \(k\) ,同时为每一个 \(v_i\) 减去 \(k\) ,使得求出过每一个 \((x,f(x))\) 的斜率为 \(k\) 的线的截距,而截距最大的那个点必然就是切点了。

好像懂了。


不,我没懂。

我不知道如何处理斜率 \(k\)\(v_i\) 的关系,因为这里是生成树,如果我所有边都减小相同的数值,答案必然不会变。哦,我知道了,只减去白色边的。

代码
#include<bits/stdc++.h>
using namespace std;
const int N=5e4+5,M=1e5+5;
int n,m,k;
struct Edge{int from,to,val,col;}e[M],E[M];
bool cmp(Edge a,Edge b){return a.val==b.val?a.col<b.col:a.val<b.val;}
struct DSU{
	int fa[N];
	void init(int n){for(int i=1;i<=n;++i)fa[i]=i;}
	int find(int x){return fa[x]==x?x:fa[x]=find(fa[x]);}
	bool merge(int u,int v){
		int fu=find(u),fv=find(v);
		if(fu!=fv) return fa[fv]=fu,true;
		return false;
	}
}d;
pair<int,int> cal(int tmp){
	int sum=0,cnt=0;
	d.init(n);
	for(int i=1;i<=m;++i)
	E[i]=e[i],E[i].val-=(!E[i].col)*tmp;
	sort(E+1,E+1+m,cmp);
	for(int i=1;i<=m;++i){
		// printf("---%d %d %d %d\n",E[i].from,E[i].to,E[i].val,E[i].col);
		if(d.merge(E[i].from,E[i].to)) sum+=E[i].val,cnt+=(!E[i].col);
	}
	return make_pair(cnt,sum);
}
int res=-1;
int main(){
	cin>>n>>m>>k;
	for(int i=1;i<=m;++i){
		scanf("%d%d%d%d",&e[i].from,&e[i].to,&e[i].val,&e[i].col);
		e[i].from++,e[i].to++;
	}
	int L=-100,R=100;
	while(L<=R){
		int Mid=(L+R)>>1;pair<int,int> tmp=cal(Mid);
		// printf("%d %d %d\n",Mid,tmp.first,tmp.second);
		if(tmp.first>=k) res=tmp.second+Mid*k,R=Mid-1;
		else L=Mid+1;
	}
	return printf("%d\n",res),0;
}

例二:P5308 [COCI2019] Quiz

不要问我为什么例二是黑,因为我也不知道,反正是别人博客里的第一道例题,我就试试看呗。

盲猜一个一个淘汰必然是最优的,但是限制了轮数,所以使用 \(\text{wqs}\) 二分?

通过 \(\text{dp}\) 来思考二分过程 。

我们用 \(f_{i,j}\) 来表示到达第 \(j\) 轮死了 \(i\) 个人的最多奖金。

\[f_{i,j}=\max_k f_{i-k,j-1}+\frac{k}{n-i+k} \]

感觉我缺少一个贪心结论,导致我不能直接搞 。

我有一个猜测,就是在第一轮直接将人数和剩余轮数持平,然后后面直接搞一个一个淘汰。。。

猜错了。


看了一下题解学习了一下。

就是 \(\text{dp}\) 的方程如果是倒序思考的话,就是只需要一维即可, \(f_i\) 表示还剩下 \(i\) 人时最大奖金。

\[f_i=\max_{0\le j<i}(f_j+\frac{i-j}{i}) \]

然后考虑找到这个最大点的过程我们使用 \(\text{wqs}\) 二分。

等等等等,我们发现这个过程中我们没有限制轮数啊。

哦,我懂了,就是我们考虑将上面的 \(\text{dp}\) 式子优化成 \(O(n)\) 然后再套一个 \(\text{wqs}\) 二分?

应该是这样的。

考虑如何优化 \(\text{dp}\)

\[f_{i}\cdot i=\max_{0\le j<i}(f_j\cdot i+i-j)\\ f_i\cdot i=f_j\cdot i+i-j\\ (f_j+1)\cdot i-f_i\cdot i=j\\ \]

维护点 \((f_j+1,j)\) 的下凸包,以斜率 \(i\) 去切这个凸包即可。

但是这个会了,如何将 \(\text{wqs}\) 二分结合进去呢?

必然的,凸函数 \(f(x)\) 表示 \(x\) 轮结束的情况下,最多的奖金,其必然是一个单调增的,且斜率单调减的函数(我不会证,但我觉得显然)。

考虑我这里用斜率去切的话相当于是在后面的式子里加一个 \(-k\) 即可,表示每一次转移会少这么多。

\[f_i=\max_{0\le j<i}(f_j+\frac{i-j}{i}-k)\\ f_i\cdot i=\max_{0\le j<i}(f_j\cdot i+i-j-ik)\\ f_i\cdot i=f_j\cdot i+i-j-ik\\ (f_j+1-k)\cdot i-f_{i}\cdot i=j \]

维护点 \((f_j+1-k,j)\) 的下凸包即可。

代码
#include<bits/stdc++.h>
using namespace std;
#define double long double
const int N=1e5+5;
int n,k;
struct Point{double x,y;};
struct Vector{double x,y;};
Vector operator - (Point a,Point b){return (Vector){a.x-b.x,a.y-b.y};}
double operator * (Vector a,Vector b){return a.x*b.y-a.y*b.x;}
double res=0,f[N];int g[N];
pair<Point,int> bag[N];
pair<int,double> cal(double k){
	int L=1,R=0;
	memset(f,0,sizeof(f)),memset(g,0,sizeof(g));
	f[1]=g[1]=1,bag[++R]=make_pair((Point){f[1]+1-k,1},g[1]);
	for(int i=2;i<=n;++i){
		while(R-L>0&&(bag[L+1].first-bag[L].first)*(Vector){1,1.0*i}>0) L++;
		f[i]=bag[L].first.x-bag[L].first.y/i,g[i]=bag[L].second+1;
		Point tmp=(Point){f[i]+1-k,1.0*i};
		while(R-L>0&&(bag[R].first-bag[R-1].first)*(tmp-bag[R-1].first)<0) R--;
		bag[++R]=make_pair(tmp,g[i]);
	}
	// for(int i=1;i<=n;++i) printf("%.9lf %d\n",f[i],g[i]);
	return make_pair(g[n],f[n]);
}
int main(){
	cin>>n>>k;
	double L=0,R=1;
	while((R-L)>1e-16){
		double Mid=(L+R)/2;pair<int,double> tmp=cal(Mid);
		// printf("%.9lf %d %.9lf\n",Mid,tmp.first,tmp.second);
		if(tmp.first>=k) res=tmp.second+Mid*(k-1),L=Mid;
		else R=Mid;
	}
	return printf("%.9Lf\n",res),0;
}

例三:CF739E Gosha is hunting

这道题目好像有点不太一样,他是存在两个限制,可能需要二分套二分?

还是从 \(\text{dp}\) 出发, \(f_{i,j,k}\) 表示到第 \(i\) 个神奇宝贝时,普通宝贝球用了 \(j\) 个,超级宝贝球用了 \(k\) 个时捕捉数最大的期望,转移即

\[f_{i,j,k}=\max\left\{\begin{matrix} f_{i-1,j,k}\\ f_{i-1,j-1,k}+p_i\\ f_{i-1,j,k-1}+u_i\\ f_{i-1,j-1,k-1}+1-(1-p_i)(1-u_i) \end{matrix}\right. \]

这个是 \(O(n^3)\) 的,需要优化,实际上发现不需要二分套二分了,只需要用 \(\text{wqs}\) 二分消去其中的一维即可。

我们令普通宝贝球没有限制,且当前其转移斜率为 \(k\) ,则 \(f_{i,j}\) 表示到第 \(i\) 个神奇宝贝,超级宝贝球用了 \(j\) 个时的最大期望,转移即

\[f_{i,j}=\max\left\{\begin{matrix} f_{i-1,j}\\ f_{i-1,j}+p_i-k\\ f_{i-1,j-1}+u_i\\ f_{i-1,j-1}+1-(1-p_i)(1-u_i)-k \end{matrix}\right. \]

说实话感觉怪怪的,因为之前的 \(\text{wqs}\) 二分好像没有出现过乘的情况,不知道可不可以处理。感觉和上面的黑白边有一点相似。

代码
#include<bits/stdc++.h>
using namespace std;
const int N=2e3+5;
int n,x,y;
double a[N],b[N];
int g[N][N];double f[N][N],res=0;
pair<int,double> cal(double k){
	f[0][0]=0,g[0][0]=0;
	for(int i=1;i<=n;++i){
		for(int j=0;j<=y;++j){
			double tmp1=0,tmp2=a[i]-k;
			if(tmp2>=tmp1) f[i][j]=f[i-1][j]+tmp2,g[i][j]=g[i-1][j]+1;
			else f[i][j]=f[i-1][j]+tmp1,g[i][j]=g[i-1][j];
			if(j){
				tmp1=b[i],tmp2=1-(1-a[i])*(1-b[i])-k;
				if(tmp2>=tmp1&&f[i-1][j-1]+tmp2>f[i][j])
				f[i][j]=f[i-1][j-1]+tmp2,g[i][j]=g[i-1][j-1]+1;
				if(tmp2<tmp1&&f[i-1][j-1]+tmp1>f[i][j])
				f[i][j]=f[i-1][j-1]+tmp1,g[i][j]=g[i-1][j-1];
			}
		}
	}
	return make_pair(g[n][y],f[n][y]);
}
int main(){
	cin>>n>>x>>y;
	for(int i=1;i<=n;++i) scanf("%lf",&a[i]);
	for(int i=1;i<=n;++i) scanf("%lf",&b[i]);
	double L=-1,R=1;
	while(R-L>1e-8){
		double Mid=(L+R)/2;pair<int,double> tmp=cal(Mid);
		if(tmp.first>=x) res=tmp.second+Mid*x,L=Mid;
		else R=Mid;
	}
	return printf("%.5lf\n",res),0;
}

闵可夫斯基和

发现这个东西和 \(\text{wqs}\) 二分结合挺多,跟其他的东西结合也很巧妙,本文这里就将一个闵可夫斯基和和 \(\text{wqs}\) 二分的结合例子。

概念介绍

自己 \(\text{google}\) 。你可以发现闵可夫斯基和实际上就是将两个凸包的斜率进行排序之后组合成一个大凸包。

为什么两者可以结合

你发现 \(\text{wqs}\) 二分出来的凸函数是可以用闵可夫斯基和来合并的,比如子问题的答案实际上就是可以通过闵可夫斯基和的合并来得到父问题的答案。

例四:GYM102331G Honorable Mention

这里发现对于每一个区间的子问题都是满足相对于取的个数为凸函数。

然后我们可以考虑在线段树上搞区间的凸函数合并,这样的话对于每一个子问题,我们就可以通过提取出对应这个区间的 \(\log_2n\) 个凸函数来计算答案。

对于这 \(\log_2n\) 个区间继续使用闵可夫斯基和来进行求解明显是不现实的,我们则可以使用 \(\text{wqs}\) 二分来降低复杂度,总复杂度是 \(O(n\log_2n+q\log_2^2n\log_2V)\) 的。

代码
#include<bits/stdc++.h>
using namespace std;
const int N=3.5e4+5;
const int INF=2147483647;
int n,q,a[N];
struct Convex_Hull{
	vector<int> f;int L,R;
	int &operator [] (int x){return f[x];}
	void resize(int n){
		while((int)f.size()>n) f.pop_back();
		while((int)f.size()<n) f.push_back(-INF);
	}
	void limit(int l,int r){L=l,R=r,resize(r+1);}
};
Convex_Hull operator + (Convex_Hull a,Convex_Hull b){
	if(a.L>a.R) return a;
	if(b.L>b.R) return b;
	Convex_Hull res;res.limit(a.L+b.L,a.R+b.R);
	int i=1,j=1;res[res.L]=a[a.L]+b[b.L];
	while(i+a.L<=a.R&&j+b.L<=b.R){
		if(a[i+a.L]-a[i+a.L-1]>b[j+b.L]-b[j+b.L-1])
		res[res.L+i+j-1]=res[res.L+i+j-2]+a[i+a.L]-a[i+a.L-1],i++;
		else res[res.L+i+j-1]=res[res.L+i+j-2]+b[j+b.L]-b[j+b.L-1],j++;
	}
	while(i+a.L<=a.R) res[res.L+i+j-1]=res[res.L+i+j-2]+a[i+a.L]-a[i+a.L-1],i++;
	while(j+b.L<=b.R) res[res.L+i+j-1]=res[res.L+i+j-2]+b[j+b.L]-b[j+b.L-1],j++;
	return res;
}
Convex_Hull max(Convex_Hull a,Convex_Hull b){
	if(a.L>a.R) return b;
	if(b.L>b.R) return a;
	Convex_Hull res;res.limit(min(a.L,b.L),max(a.R,b.R));
	for(int i=res.L;i<=res.R;++i){
		if(a.L<=i&&i<=a.R) res[i]=max(res[i],a[i]);
		if(b.L<=i&&i<=b.R) res[i]=max(res[i],b[i]);
	}
	return res;
}
struct Seg_Tree{
	struct Node{Convex_Hull data[2][2];}tr[N<<2];
	void up(int u){
		Convex_Hull tmp;
		for(int i=0;i<2;++i){
			for(int j=0;j<2;++j)
			tr[u].data[i][j].limit(1,0);
		}
		for(int i=0;i<2;++i)
		for(int j=0;j<2;++j)
		for(int c=0;c<2;++c)
		for(int d=0;d<2;++d){
			tmp=tr[u<<1].data[i][j]+tr[u<<1|1].data[c][d];
			tr[u].data[i][d]=max(tr[u].data[i][d],tmp);
			if(j&&c){
				tmp.L--,tmp.R--;
				for(int e=tmp.L;e<=tmp.R;++e) tmp[e]=tmp[e+1];
				tr[u].data[i][d]=max(tr[u].data[i][d],tmp);
			}
		}
	}
	void build(int u,int l,int r,int a[]){
		if(l==r){
			tr[u].data[0][0].limit(0,0);
			tr[u].data[0][0][0]=0;
			tr[u].data[0][1].limit(1,0);
			tr[u].data[1][0].limit(1,0);
			tr[u].data[1][1].limit(1,1);
			tr[u].data[1][1][1]=a[l];
			return void();
		}
		int mid=(l+r)>>1;
		build(u<<1,l,mid,a);
		build(u<<1|1,mid+1,r,a);
		return up(u);
	}
	void query(int u,int l,int r,int x,int y,vector<Convex_Hull*> bag[][2]){
		if(x<=l&&r<=y){
			bag[0][0].push_back(&tr[u].data[0][0]);
			bag[0][1].push_back(&tr[u].data[0][1]);
			bag[1][0].push_back(&tr[u].data[1][0]);
			bag[1][1].push_back(&tr[u].data[1][1]);
			return ;
		}
		int mid=(l+r)>>1;
		if(x<=mid) query(u<<1,l,mid,x,y,bag);
		if(y>mid) query(u<<1|1,mid+1,r,x,y,bag);
		return ;
	}
}t;
struct Data{long long data;int cnt;};
bool operator < (Data a,Data b){return a.data==b.data?a.cnt>b.cnt:a.data<b.data;}
Data operator + (Data a,Data b){return (Data){a.data+b.data,a.cnt+b.cnt};}
Data search(int k,Convex_Hull* a){
	if(a->L>a->R) return (Data){-INF,INF};
	int l=a->L,r=a->R-1,res=a->R;
	while(l<=r){
		int mid=(l+r)>>1;
		if(a->f[mid+1]-a->f[mid]<=k) r=mid-1,res=mid;
		else l=mid+1;
	}
	return (Data){a->f[res]-1ll*res*k,res};
}
Data cal(int k,vector<Convex_Hull*> bag[][2]){
	int m=bag[0][0].size();
	Data f[30][2];
	for(int i=1;i<m;++i) f[i][0]=f[i][1]=(Data){-INF,INF};
	f[0][0]=max(search(k,bag[0][0][0]),search(k,bag[1][0][0]));
	f[0][1]=max(search(k,bag[0][1][0]),search(k,bag[1][1][0]));
	for(int i=1;i<m;++i){
		Data tmp;
		for(int c=0;c<2;++c){
			for(int d=0;d<2;++d){
				tmp=search(k,bag[c][d][i]);
				for(int j=0;j<2;++j){
					f[i][d]=max(f[i][d],f[i-1][j]+tmp);
					if(j&&c){
						tmp.data+=k,tmp.cnt--;
						f[i][d]=max(f[i][d],f[i-1][j]+tmp);
						tmp.data-=k,tmp.cnt++;
					}
				}
			}
		}
	}
	return max(f[m-1][0],f[m-1][1]);
}
signed main(){
	// freopen("ex_inaugurate5.in","r",stdin);
	// freopen("inaugurate.out","w",stdout);
	// cin>>n;
	cin>>n>>q;
	for(int i=1;i<=n;++i) scanf("%d",&a[i]);
	t.build(1,1,n,a);
	for(int i=1;i<=q;++i){
		int l,r,k;vector<Convex_Hull*> bag[2][2];
		scanf("%d%d%d",&l,&r,&k),t.query(1,1,n,l,r,bag);
		int L=-N*N,R=N*N,res1=-1;Data res2=(Data){0ll,0};
		while(L<=R){
			int Mid=(1ll+L+R)>>1;Data tmp=cal(Mid,bag);
			if(tmp.cnt<=k) R=Mid-1,res1=Mid,res2=tmp;else L=Mid+1;
		}
		printf("%lld\n",res2.data+1ll*k*res1);
	}
}

推荐阅读