首页 > 技术文章 > IOI2021集训队作业227CB Branch Assignment

jz-597 2020-11-26 20:21 原文

核心问题:给出长度为\(b\)的数组\(c_i\),把每个数分到恰好\(s\)个集合中,最小化\(\sum c_i|c_i所在集合|\)

\(b,s\le 5000\)


考虑最终状态下的两个集合\(S,T\),假定\(|S|\le|T|\),如果有\(x\in S,y\in T\),如果\(x< y\),则把\(x,y\)交换一定不会更劣。

于是小的数放到大集合,大的数放到小的集合;除此外可以推广到:给\(c_i\)排序,那么一个集合一定对应着一段区间。

DP设\(f_{i,j}\)表示搞了前\(i\)个数,选了\(j\)个集合,可以\(O(n^3)\)做。

发现\(f_{i,j}\)关于\(j\)是凸函数,于是可以凸优化。时间\(O(n^2\lg)\)

似乎也可以决策单调性做。(类似这题?看着转移方程挺像的)

虽然会口胡但是之前从来就没有写过凸优化,这次是第一次写,一开始二分的斜率还取的是实数……凸优化二分的斜率不需要取实数的……


using namespace std;
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <cassert>
#include <cmath>
#define N 5005
#define M 50005
#define ll long long
#define fi first
#define se second
#define INF 1000000000000000000
int n,m,B,S;
struct EDGE{
	int to,w;
	EDGE *las;
};
struct Graph{
	EDGE e[M];
	int ne;
	EDGE *last[N];
	void link(int u,int v,int w){
		e[ne]={v,w,last[u]};
		last[u]=e+ne++;
	}
} G0,G1;
ll c0[N],c1[N],c[N];
void S_P(Graph &G,ll dis[]){
	static bool inq[N];
	static queue<int> q;
	memset(dis,127,sizeof(ll)*(n+1));
	dis[B+1]=0;
	q.push(B+1),inq[B+1]=1;
	while (!q.empty()){
		int x=q.front();
		q.pop();
		for (EDGE *ei=G.last[x];ei;ei=ei->las)
			if (dis[x]+ei->w<dis[ei->to]){
				dis[ei->to]=dis[x]+ei->w;
				if (!inq[ei->to])
					q.push(ei->to),inq[ei->to]=1;
			}
		inq[x]=0;
	}
}
ll ps[N];
pair<ll,int> f[N];
inline void dp(ll w){
	f[0]={0,0};
	for (register int i=1;i<=B;++i){
		f[i]={INF,0};
		for (register int k=0;k<i;++k){
			ll tmp=f[k].fi+(i-k)*(ps[i]-ps[k]);
			if (tmp<f[i].fi)
				f[i].fi=tmp,f[i].se=f[k].se+1;
		}
		f[i].fi-=w;
	}
}
int main(){
//	freopen("in.txt","r",stdin);
//	freopen("out.txt","w",stdout);
	scanf("%d%d%d%d",&n,&B,&S,&m);
	for (int i=1;i<=m;++i){
		int u,v,w;
		scanf("%d%d%d",&u,&v,&w);
		G0.link(u,v,w);
		G1.link(v,u,w);
	}
	S_P(G0,c0);
	S_P(G1,c1);
	for (int i=1;i<=B;++i)
		c[i]=c0[i]+c1[i];
//	for (int i=1;i<=B;++i)
//		scanf("%lld",&c[i]);
	sort(c+1,c+B+1);
	for (int i=1;i<=B;++i)
		ps[i]=ps[i-1]+c[i];
	ll l=-ps[B]*B,r=0,res=0;
	while (l<=r){
		ll mid=l+r>>1;
		dp(mid);
		if (f[B].se<=S)
			l=(res=mid)+1;
		else
			r=mid-1;
	}
	dp(res);
	ll ans=f[B].fi+res*S;
	printf("%lld\n",ans-ps[B]);
	return 0;
}

推荐阅读