https://www.luogu.com.cn/problem/P4383
\(wqs\)二分/树型\(DP\)
可以看到,题目本质上要求的是取\(k+1\)条链,使其边权和最大
先打一个树型\(DP\)(够头疼的了)
\(dp_{i,j,0/1/2}\),\(i\)表示哪一个节点,\(j\)表示已经用了几条链了,\(0/1/2\)代表的是一个节点的度数,注意,当度数为\(1\)时,我们不计算该条链的贡献(因为可能与其他链拼接)
注意,一个点也可以作为一条链,我们直接初始化\(dp_{i,1,2}=0\),表示该点单独取
此外,\(dp_{i,0,0}=0\)(不取,显然),其他位置的权值都为\(-INF\)
\(C++ Code:\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#define ll long long
#define N 600005
#define INF 12345678987654321
using namespace std;
int tot,head[N],nxt[N],d1[N],d2[N],n,k,x,y,z;
ll ans,dp[N >> 1][105][3],kz[105][3];
void add(int x,int y,int z)
{
tot++;
d1[tot]=y,d2[tot]=z;
nxt[tot]=head[x];
head[x]=tot;
}
void ckmx(ll &x,ll y)
{
if (x<y)
x=y;
}
void dfs(int u,int f)
{
for (int i=head[u];i;i=nxt[i])
{
int v=d1[i];
int cost=d2[i];
if (v==f)
continue;
dfs(v,u);
for (int t=0;t<=k;t++)
for (int q=0;q<3;q++)
kz[t][q]=dp[u][t][q];
for (int t=k;t>=0;t--)
{
for (int q=0;q<=t;q++)
{
ckmx(dp[u][t][0],dp[v][q][0]+kz[t-q][0]);
if (t>q)
ckmx(dp[u][t][0],dp[v][q][1]+kz[t-q-1][0]);
ckmx(dp[u][t][0],dp[v][q][2]+kz[t-q][0]);
ckmx(dp[u][t][1],dp[v][q][0]+kz[t-q][1]);
ckmx(dp[u][t][1],dp[v][q][0]+kz[t-q][0]+cost);
if (t>q)
ckmx(dp[u][t][1],dp[v][q][1]+kz[t-q-1][1]);
ckmx(dp[u][t][1],dp[v][q][1]+kz[t-q][0]+cost);
ckmx(dp[u][t][1],dp[v][q][2]+kz[t-q][1]);
ckmx(dp[u][t][2],dp[v][q][0]+kz[t-q][2]);
if (t>q)
ckmx(dp[u][t][2],dp[v][q][0]+kz[t-q-1][1]+cost);
if (t>q)
ckmx(dp[u][t][2],dp[v][q][1]+kz[t-q-1][2]);
if (t>q)
ckmx(dp[u][t][2],dp[v][q][1]+kz[t-q-1][1]+cost);
ckmx(dp[u][t][2],dp[v][q][2]+kz[t-q][2]);
}
}
}
}
int main()
{
scanf("%d%d",&n,&k),k++;
for (int i=1;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
for (int i=0;i<=n;i++)
{
for (int j=0;j<=k;j++)
for (int t=0;t<3;t++)
dp[i][j][t]=-INF;
dp[i][0][0]=dp[i][1][2]=0;
}
dfs(1,0);
ans=max(dp[1][k][0],max(dp[1][k-1][1],dp[1][k][2]));
printf("%lld\n",ans);
return 0;
}
加一发\(wqs\)
为什么呢,感性理解一下,随着\(k\)的增大,取值应该先上升后下降(链少的时候多取一点优,链多的时候无法整条整条取,答案开始下降)
对于\(k\)上升的部分,多取一条链,答案的增加值应该是下降的(因为一开始可以取很长的,后来只能取较短的)
对于\(k\)下降的部分,多取一条链,答案的减小值应该是上升的(因为不得不拆开更多的链,而拆的时候肯定是先拆损失较小的,后拆损失较大的)
由于可能二分到直线与凸壳中的线段相切,那么为了防止误判,我们在\(dp\)时记录满足最大值的,需要的最少链数。如果我们找到了答案,即可直接返回,如果没有,那么我们可以记录最右的,选取的链数小于\(k\)的那个点,虽然那个点不是我们所需要的点,但是它的斜率和我们所求点是一致的,可以直接代入。
\(dp\)方程写的\(too \quad low\),调了老半天\(QAQ\)
\(C++ Code:\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#define ll long long
#define N 300005
#define INF 10000000000000000
using namespace std;
const int iINF=1000000007;
int tot,head[N],nxt[N << 1],d1[N << 1],d2[N << 1],n,k,x,y,z,ch[N][3];
ll dp[N][3];
void add(int x,int y,int z)
{
tot++;
d1[tot]=y,d2[tot]=z;
nxt[tot]=head[x];
head[x]=tot;
}
void ckmx(ll &x,ll y)
{
if (y>x)
x=y;
}
ll Tmx(ll x,ll y,ll z)
{
ll mx=x;
ckmx(mx,y);
ckmx(mx,z);
return mx;
}
ll Fimx(ll a,ll b,ll c,ll d,ll e)
{
ll mx=a;
ckmx(mx,b);
ckmx(mx,c);
ckmx(mx,d);
ckmx(mx,e);
return mx;
}
void dfs(int u,int f,ll mid)
{
for (int i=head[u];i;i=nxt[i])
{
int v=d1[i];
int cost=d2[i];
if (v==f)
continue;
dfs(v,u,mid);
ll k0=dp[u][0],k1=dp[u][1],k2=dp[u][2];
int c0=ch[u][0],c1=ch[u][1],c2=ch[u][2];
ll g=Tmx(k0+dp[v][0],k0+dp[v][1]-mid,k0+dp[v][2]);
if (g>dp[u][0])
{
ch[u][0]=iINF;
if (g==k0+dp[v][0])
ch[u][0]=c0+ch[v][0];
if (g==k0+dp[v][1]-mid && c0+ch[v][1]+1<ch[u][0])
ch[u][0]=c0+ch[v][1]+1;
if (g==k0+dp[v][2] && c0+ch[v][2]<ch[u][0])
ch[u][0]=c0+ch[v][2];
dp[u][0]=g;
}
g=Fimx(k0+dp[v][0]+cost,k1+dp[v][0],k1+dp[v][1]-mid,k0+dp[v][1]+cost,k1+dp[v][2]);
if (g>dp[u][1])
{
ch[u][1]=iINF;
if (g==k0+dp[v][0]+cost)
ch[u][1]=c0+ch[v][0];
if (g==k1+dp[v][0] && c1+ch[v][0]<ch[u][1])
ch[u][1]=c1+ch[v][0];
if (g==k1+dp[v][1]-mid && c1+ch[v][1]+1<ch[u][1])
ch[u][1]=c1+ch[v][1]+1;
if (g==k0+dp[v][1]+cost && c0+ch[v][1]<ch[u][1])
ch[u][1]=c0+ch[v][1];
if (g==k1+dp[v][2] && c1+ch[v][2]<ch[u][1])
ch[u][1]=c1+ch[v][2];
dp[u][1]=g;
}
g=Fimx(k2+dp[v][0],k1+dp[v][0]-mid+cost,k2+dp[v][1]-mid,k1+dp[v][1]-mid+cost,k2+dp[v][2]);
if (g>dp[u][2])
{
ch[u][2]=iINF;
if (g==k2+dp[v][0])
ch[u][2]=c2+ch[v][0];
if (g==k1+dp[v][0]-mid+cost && c1+ch[v][0]+1<ch[u][2])
ch[u][2]=c1+ch[v][0]+1;
if (g==k2+dp[v][1]-mid && c2+ch[v][1]+1<ch[u][2])
ch[u][2]=c2+ch[v][1]+1;
if (g==k1+dp[v][1]-mid+cost && c1+ch[v][1]+1<ch[u][2])
ch[u][2]=c1+ch[v][1]+1;
if (g==k2+dp[v][2] && c2+ch[v][2]<ch[u][2])
ch[u][2]=c2+ch[v][2];
dp[u][2]=g;
}
}
}
ll ans2;
int ans1;
void check(ll mid)
{
for (int i=1;i<=n;i++)
{
dp[i][0]=0;
ch[i][0]=0;
dp[i][1]=-INF;
ch[i][1]=-10000000;
dp[i][2]=-mid;
ch[i][2]=1;
}
dfs(1,0,mid);
ll g=Tmx(dp[1][0],dp[1][1]-mid,dp[1][2]);
ans1=iINF;
if (g==dp[1][0])
{
ans1=ch[1][0];
ans2=dp[1][0];
}
if (g==dp[1][1]-mid && ch[1][1]+1<ans1)
{
ans1=ch[1][1]+1;
ans2=dp[1][1]-mid;
}
if (g==dp[1][2] && ch[1][2]<ans1)
{
ans1=ch[1][2];
ans2=dp[1][2];
}
}
int main()
{
scanf("%d%d",&n,&k),k++;
ll l=0,r=0;
for (int i=1;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
r+=abs(z);
}
r++;
l=-r;
ll ansl=-INF;
while (l<=r)
{
ll mid=(l+r) >> 1;
check(mid);
if (ans1==k)
{
ansl=mid;
break;
}
if (ans1>k)
l=mid+1; else
ansl=mid,r=mid-1;
}
check(ansl);
printf("%lld\n",ans2+ansl*k);
return 0;
}