首页 > 技术文章 > 6846. 【2020.11.02提高组模拟】旅人1970

jz-597 2020-11-02 21:33 原文

给出一个可重集,\(a_i\)表示\(i\)的出现次数。分成若干个集合,最小化每个集合的\(2\)的最小众数次方之和。

支持修改。

\(n\le 2*10^5,1\le a_i\le 10^5\)

\(q\le 10^5\)


题目显然可以转化成:选择不可重集合\(S\),满足:\(\sum a_{s_i}\ge \max a_i\),然后使\(S\)从大往小字典序最小。

这样就可以得到\(O(nq)\)做法:找到最小的前缀,使得前缀和大于等于\(\max a_i\),然后求出多余的部分,从后往前扫能删就删。

这里有性质:删的连续段个数为\(O(\sqrt V)\)级别,\(V\)为值域。

具体证明考虑:对于一个没有删的位置\(i\),一定有\(a_i>\sum_{s_j<i}a_{s_j}\)。把没有删的段提出来,最坏情况下每段只有一个并且取到下界。可以发现不同段的下界是递增的(因为\(a_i\)为正数),并且没有删的数的和小于\(2V\)(否则可以再删)。于是连续段个数为\(O(\sqrt V)\)

时间复杂度\(O(q\sqrt V\lg n)\)

题解说卡不满,的确感觉上卡不满,但是为什么我吸氧后才过了?


#pragma GCC optimize("O2")
#pragma G++ optimize("O2")
using namespace std;
#include <cstdio>
#include <cstring>
#include <set>
#define N 200005
#define ll long long
#define mo 998244353
#define min(a,b) ((a)<(b)?(a):(b))
int input(){
	char ch=getchar();
	while (ch<'0' || ch>'9')
		ch=getchar();
	int x=0;
	do{
		x=x*10+ch-'0';
		ch=getchar();
	}
	while ('0'<=ch && ch<='9');
	return x;
}
int Num,n;
int pw2[N];
int a[N];
multiset<int> s;
ll sum[N*4];
int mn[N*4];
void build(int k,int l,int r){
	if (l==r){
		sum[k]=mn[k]=a[l];
		return;
	}
	int mid=l+r>>1;
	build(k<<1,l,mid);
	build(k<<1|1,mid+1,r);
	mn[k]=min(mn[k<<1],mn[k<<1|1]);
	sum[k]=sum[k<<1]+sum[k<<1|1];
}
void change(int k,int l,int r,int x,int c){
	if (l==r){
		sum[k]=mn[k]=c;
		return;
	}
	int mid=l+r>>1;
	if (x<=mid) change(k<<1,l,mid,x,c);
	else change(k<<1|1,mid+1,r,x,c);
	mn[k]=min(mn[k<<1],mn[k<<1|1]);
	sum[k]=sum[k<<1]+sum[k<<1|1];
}
pair<int,int> find(int k,int l,int r,int v){
	if (l==r)
		return make_pair(l,sum[k]-v);
	int mid=l+r>>1;
	if (v<=sum[k<<1]) return find(k<<1,l,mid,v);
	return find(k<<1|1,mid+1,r,v-sum[k<<1]);
}
ll res;
ll getsum(int l,int r){return pw2[r+1]-pw2[l];}
void dfs(int k,int l,int r,int en,int &s){
	if (mn[k]>s) return;
	if (r<=en && sum[k]<=s){
		s-=sum[k];
		res-=getsum(l,r);
		return;
	}
	int mid=l+r>>1;
	if (mid<en)
		dfs(k<<1|1,mid+1,r,en,s);
	dfs(k<<1,l,mid,en,s);
}
ll query(){
	int mx=*s.rbegin();
	pair<int,int> tmp=find(1,1,n,mx);
	res=getsum(1,tmp.first);
	if (tmp.first>1)
		dfs(1,1,n,tmp.first-1,tmp.second);
	res=(res%mo+mo)%mo;
	return res;
}
int main(){
//	freopen("imperishable.in","r",stdin);
//	freopen("imperishable.out","w",stdout);
	freopen("in.txt","r",stdin);
	freopen("out.txt","w",stdout);
	Num=input(),n=input();
	pw2[0]=1;
	for (int i=1;i<=n+1;++i)
		pw2[i]=pw2[i-1]*2%mo;
	for (int i=1;i<=n;++i)
		s.insert(a[i]=input());
	build(1,1,n);
	printf("%lld\n",query());
	int Q;
	scanf("%d",&Q);
	while (Q--){
		int x=input(),y=input();
		s.erase(s.find(a[x]));
		s.insert(y);
		change(1,1,n,x,a[x]=y);
		printf("%lld\n",query());
	}
	return 0;
}

推荐阅读