首页 > 技术文章 > CF1401D Maximum Distributed Tree

jd1412 2020-11-28 21:21 原文

传送门

分析

我们要最大化边权求和: \(\sum_{i=1}^{n-1}\sum_{j=i+1}^{n}{f(i,j)}\)

其中 \(f(i,j)\) 表示树上从 \(i\) 节点走到 \(j\) 节点的简单路径的边权和

那么贪心的讲,我们对于走过次数最多的边赋予最高的权值必然是最优的,给出证明:

证明

设有最大边权和次大边权 \(k_1\)\(k_2\) ,走过次数最多的边的次数和次多的边的次数 \(sum_1\)\(sum_2\)

那么若证 \(k_1 \times sum_1 + k_2 \times sum_2 > k_2 \times sum_1 + k_1 \times sum_2\)

即证 : \((k_1 - k_2) \times (sum_1 - sum_2) > 0\)

\(k_1 > k_2\)\(sum_1 > sum_2\) ,故不等式必然成立

那么我们把这个不等式类推到 \(n\) 项 , 即 :设边权从大到小为 \(k_1 , k_2 , …… , k_n\) , 边被遍历到的次数从大到小为 \(sum_1 , sum_2 , sum_3 , …… ,sum_n\)

则有:

\[k_1 \times sum_1 + k_2 \times sum_2 + …… + k_n \times sum_2 \geq ……… \geq k_1 \times sum_n + k_2 \times sum_{n-1} + …… + k_n \times sum_1 \]

,即由排序不等式,成立

排序不等式证明:
80cb39dbb6fd5266d016f1ef6051802bd40734fa98bc.png

Solution

那么我们可以对每一条边进行考虑,在这一条边左边的点(包括边的端点)的数量与右边的点的数量的乘积,即为这条边在路径中被走过的总次数,那么我们可以那一个堆来从大到小记录每条边被走过的次数

\(k\) 的质因子的数量小于等于边数的时候,为了满足使边权为 1 的边的数量最少的要求,直接将质因子从大到小排序后,每次从堆中(按边的遍历次数从大到小)取出一个数与之相乘,则总和即为最大值

\(k\) 的质因子数量大于边数的时候,由上述分析中得到的结论,我们只要把质因子从大到小相乘,然后将乘过的因子删去,直到因子的数量最后剩下 \(n-1\) 个(即边数),则将上述因子的乘积作为其最大边权,其余 \(n-2\) 个因子为其他的边权,按照 \(k\) 小于等于边数的时候处理即可

Code

#include<iostream>
#include<cstdio>
#include<math.h>
#include<queue>
#include<vector>
#include<cstring>
#include<algorithm>
#define ll long long

const ll mod=1e9+7;
const ll maxn=1e5+10;
ll t,n,m,ans;
ll p[maxn],s[maxn];
std::vector<ll> e[maxn]; 
std::priority_queue<ll> q;

inline void cle()
{
	memset(s,0,sizeof(s));
	memset(p,0,sizeof(p));
	ans=0;
	for(int i=1;i<=n;i++) e[i].clear();
	while(q.size())
	{
		q.pop();
	}
}

inline void dfs(ll x,ll f)
{
	s[x]=1;
	for(int i=0;i<e[x].size();i++)
	{
		ll y=e[x][i];
		
		if(y==f) continue;
		
		dfs(y,x);
		
		s[x]+=s[y];
	}
}

inline void ddfs(ll x,ll f)//计算每条边作为经过的次数
{
	for(int i=0;i<e[x].size();i++)
	{
		ll y=e[x][i];
		
		if(y==f) continue;
		
		q.push(s[y]*(n-s[y]));
		
		ddfs(y,x);
	}
}

inline bool cmp(ll a,ll b)
{
	return a>b;
}

int main(void)
{
	scanf("%lld",&t);
	
	while(t--)
	{
		cle();//多组输入,记得清空
		scanf("%lld",&n);
		
		for(int i=1;i<=n-1;i++)
		{
			ll x,y;
			scanf("%lld %lld",&x,&y);
			e[x].push_back(y);
			e[y].push_back(x);
		}
		
		scanf("%lld",&m);
		for(int i=1;i<=m;i++) scanf("%lld",p+i);
		
		dfs(1,0);
		ddfs(1,0);
		
		std::sort(p+1,p+m+1,cmp);
		
		if(m<=(n-1))//因子数小于等于边数
		{
			ll top=1;
			while(q.size())
			{
				if(top<=m) (ans+=q.top()*p[top])%=mod;
				else (ans+=q.top())%=mod;
				
				top++;
				q.pop();
			}
		}
		else//因子数大于边数
		{
			ll top=m-n+2;
			ll sum=1;
			for(int i=1;i<=(m-(n-1))+1;i++)
			{
				(sum*=p[i])%=mod;
			}
			p[top]=sum;
			
			while(q.size())
			{
				(ans+=q.top()*p[top])%=mod;
				q.pop();
				top++;
			}
		}
		
		printf("%lld\n",ans);
	}
	
	return 0;
}

推荐阅读