首页 > 技术文章 > BSOJ6387题解

lmpp 2022-03-07 14:33 原文

算是刷新了我对树上问题的认知

首先第一问随便做一个 \(O(nk)\) 的 DP 就可以草过去,考虑第二问。

我们将问题分为两个部分:走儿子边的答案和走父亲边的答案。最后拼接一下就好了。

\(fd[u][k]\) 是走儿子边且距离不超过 \(k\) 的节点数量,\(fu[u][k]\) 是走父亲边的答案;\(gd[u][k]\) 是走儿子边的拥挤程度,\(gu[u][k]\) 同理。

这几个转移起来相当简单,不再赘述。可以做到 \(O(nk\log n)\)\(O(n(k+\log n))\)

#include<cstdio>
typedef unsigned ui;
const ui M=1e5+5,K=15,mod=1e9+7;
ui n,k,cnt,h[M],f[M],ans[M],fd[M][K],fu[M][K],gd[M][K],gu[M][K];
struct Edge{
	ui v,nx;
}e[M<<1];
inline void Add(const ui&u,const ui&v){
	e[++cnt]=(Edge){v,h[u]};h[u]=cnt;
	e[++cnt]=(Edge){u,h[v]};h[v]=cnt;
}
inline void swap(ui&a,ui&b){
	ui c=a;a=b;b=c;
}
inline ui pow(ui a,ui b){
	ui ans(1);for(;b;b>>=1,a=1ull*a*a%mod)if(b&1)ans=1ull*ans*a%mod;return ans;
}
inline void init(const ui&u){
	for(ui v,E=h[u];E;E=e[E].nx)if((v=e[E].v)^f[u])f[v]=u,init(v);
}
inline void DFS1(const ui&u){
	for(ui i=0;i<=k;++i)fd[u][i]=gd[u][i]=1;
	for(ui v,E=h[u];E;E=e[E].nx)if((v=e[E].v)^f[u]){
		DFS1(v);
		for(ui i=1;i<=k;++i){
			fd[u][i]+=fd[v][i-1];gd[u][i]=1ull*gd[u][i]*gd[v][i-1]%mod;
		}
	}
	for(ui i=0;i<=k;++i)gd[u][i]=1ull*gd[u][i]*fd[u][i]%mod;
}
inline void DFS2(const ui&u){
	static ui t[K],inv[K];inv[1]=1;
	for(ui i=0;i<=k;++i)fu[u][i]=gu[u][i]=1;
	if(u!=1){
		++fu[u][1];
		gu[u][1]=1ull*gu[f[u]][0]*gd[f[u]][0]%mod*fu[u][1]%mod;
		for(ui i=2;i<=k;++i){
			const ui&sz1=fu[f[u]][i-1],&sz2=fd[f[u]][i-1],&sz3=fd[u][i-2];
			fu[u][i]=sz1+sz2-sz3;
			gu[u][i]=1ull*gu[f[u]][i-1]*gd[f[u]][i-1]%mod*(fu[u][i]-1)%mod*fu[u][i]%mod;
			t[i]=1ull*gd[u][i-2]*sz1%mod*sz2%mod;
			inv[i]=1ull*inv[i-1]*t[i]%mod;
		}
		inv[k]=pow(inv[k],mod-2);
		for(ui i=k;i>1;--i)swap(inv[i],inv[i-1]),inv[i]=1ull*inv[i]*inv[i-1]%mod,inv[i-1]=1ull*inv[i-1]*t[i]%mod;
		for(ui i=2;i<=k;++i)gu[u][i]=1ull*gu[u][i]*inv[i]%mod;
	}
	for(ui v,E=h[u];E;E=e[E].nx)if((v=e[E].v)^f[u])DFS2(v);
}
signed main(){
	scanf("%u%u",&n,&k);
	for(ui i=1;i<n;++i){
		ui u,v;scanf("%u%u",&u,&v);
		Add(u,v);
	}
	init(1);DFS1(1);DFS2(1);
	for(ui u=1;u<=n;++u){
		const ui&sz1=fd[u][k],sz2=fu[u][k];
		ans[u]=1ull*gd[u][k]*gu[u][k]%mod*pow(1ull*sz1*sz2%mod,mod-2)%mod*(sz1+sz2-1)%mod;
		printf("%u ",sz1+sz2-1);
	}
	printf("\n");
	for(ui u=1;u<=n;++u)printf("%u ",ans[u]);
}

推荐阅读