首页 > 技术文章 > Luogu4383 [八省联考2018]林克卡特树

GK0328 2020-08-18 18:30 原文

https://www.luogu.com.cn/problem/P4383

\(wqs\)二分/树型\(DP\)

可以看到,题目本质上要求的是取\(k+1\)条链,使其边权和最大

先打一个树型\(DP\)(够头疼的了)

\(dp_{i,j,0/1/2}\),\(i\)表示哪一个节点,\(j\)表示已经用了几条链了,\(0/1/2\)代表的是一个节点的度数,注意,当度数为\(1\)时,我们不计算该条链的贡献(因为可能与其他链拼接)

注意,一个点也可以作为一条链,我们直接初始化\(dp_{i,1,2}=0\),表示该点单独取

此外,\(dp_{i,0,0}=0\)(不取,显然),其他位置的权值都为\(-INF\)

\(C++ Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#define ll long long
#define N 600005
#define INF 12345678987654321
using namespace std;
int tot,head[N],nxt[N],d1[N],d2[N],n,k,x,y,z;
ll ans,dp[N >> 1][105][3],kz[105][3];
void add(int x,int y,int z)
{
    tot++;
    d1[tot]=y,d2[tot]=z;
    nxt[tot]=head[x];
    head[x]=tot;
}
void ckmx(ll &x,ll y)
{
    if (x<y)
        x=y;
}
void dfs(int u,int f)
{
    for (int i=head[u];i;i=nxt[i])
    {
        int v=d1[i];
        int cost=d2[i];
        if (v==f)
            continue;
        dfs(v,u);
        for (int t=0;t<=k;t++)
            for (int q=0;q<3;q++)
                kz[t][q]=dp[u][t][q];
        for (int t=k;t>=0;t--)
        {
            for (int q=0;q<=t;q++)
            {
                ckmx(dp[u][t][0],dp[v][q][0]+kz[t-q][0]);
                if (t>q)
                    ckmx(dp[u][t][0],dp[v][q][1]+kz[t-q-1][0]);
                ckmx(dp[u][t][0],dp[v][q][2]+kz[t-q][0]);
                ckmx(dp[u][t][1],dp[v][q][0]+kz[t-q][1]);
                ckmx(dp[u][t][1],dp[v][q][0]+kz[t-q][0]+cost);
                if (t>q)
                    ckmx(dp[u][t][1],dp[v][q][1]+kz[t-q-1][1]);
                ckmx(dp[u][t][1],dp[v][q][1]+kz[t-q][0]+cost);
                ckmx(dp[u][t][1],dp[v][q][2]+kz[t-q][1]);
                ckmx(dp[u][t][2],dp[v][q][0]+kz[t-q][2]);
                if (t>q)
                    ckmx(dp[u][t][2],dp[v][q][0]+kz[t-q-1][1]+cost);
                if (t>q)
                    ckmx(dp[u][t][2],dp[v][q][1]+kz[t-q-1][2]);
                if (t>q)
                    ckmx(dp[u][t][2],dp[v][q][1]+kz[t-q-1][1]+cost);
                ckmx(dp[u][t][2],dp[v][q][2]+kz[t-q][2]);
            }
        }
    }
}
int main()
{
    scanf("%d%d",&n,&k),k++;
    for (int i=1;i<n;i++)
    {
        scanf("%d%d%d",&x,&y,&z);
        add(x,y,z);
        add(y,x,z);
    }
    for (int i=0;i<=n;i++)
    {
        for (int j=0;j<=k;j++)
            for (int t=0;t<3;t++)
                dp[i][j][t]=-INF;
        dp[i][0][0]=dp[i][1][2]=0;
    }
    dfs(1,0);
    ans=max(dp[1][k][0],max(dp[1][k-1][1],dp[1][k][2]));
    printf("%lld\n",ans);
    return 0;
}

加一发\(wqs\)

为什么呢,感性理解一下,随着\(k\)的增大,取值应该先上升后下降(链少的时候多取一点优,链多的时候无法整条整条取,答案开始下降)

对于\(k\)上升的部分,多取一条链,答案的增加值应该是下降的(因为一开始可以取很长的,后来只能取较短的)

对于\(k\)下降的部分,多取一条链,答案的减小值应该是上升的(因为不得不拆开更多的链,而拆的时候肯定是先拆损失较小的,后拆损失较大的)

由于可能二分到直线与凸壳中的线段相切,那么为了防止误判,我们在\(dp\)时记录满足最大值的,需要的最少链数。如果我们找到了答案,即可直接返回,如果没有,那么我们可以记录最右的,选取的链数小于\(k\)的那个点,虽然那个点不是我们所需要的点,但是它的斜率和我们所求点是一致的,可以直接代入。

\(dp\)方程写的\(too \quad low\),调了老半天\(QAQ\)

\(C++ Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#define ll long long
#define N 300005
#define INF 10000000000000000
using namespace std;
const int iINF=1000000007;
int tot,head[N],nxt[N << 1],d1[N << 1],d2[N << 1],n,k,x,y,z,ch[N][3];
ll dp[N][3];
void add(int x,int y,int z)
{
    tot++;
    d1[tot]=y,d2[tot]=z;
    nxt[tot]=head[x];
    head[x]=tot;
}
void ckmx(ll &x,ll y)
{
    if (y>x)
        x=y;
}
ll Tmx(ll x,ll y,ll z)
{
    ll mx=x;
    ckmx(mx,y);
    ckmx(mx,z);
    return mx;
}
ll Fimx(ll a,ll b,ll c,ll d,ll e)
{
    ll mx=a;
    ckmx(mx,b);
    ckmx(mx,c);
    ckmx(mx,d);
    ckmx(mx,e);
    return mx;
}
void dfs(int u,int f,ll mid)
{
    for (int i=head[u];i;i=nxt[i])
    {
        int v=d1[i];
        int cost=d2[i];
        if (v==f)
            continue;
        dfs(v,u,mid);
        ll k0=dp[u][0],k1=dp[u][1],k2=dp[u][2];
        int c0=ch[u][0],c1=ch[u][1],c2=ch[u][2];
        ll g=Tmx(k0+dp[v][0],k0+dp[v][1]-mid,k0+dp[v][2]);
        if (g>dp[u][0])
        {
            ch[u][0]=iINF;
            if (g==k0+dp[v][0])
                ch[u][0]=c0+ch[v][0];
            if (g==k0+dp[v][1]-mid && c0+ch[v][1]+1<ch[u][0])
                ch[u][0]=c0+ch[v][1]+1;
            if (g==k0+dp[v][2] && c0+ch[v][2]<ch[u][0])
                ch[u][0]=c0+ch[v][2];
            dp[u][0]=g;
        }
        g=Fimx(k0+dp[v][0]+cost,k1+dp[v][0],k1+dp[v][1]-mid,k0+dp[v][1]+cost,k1+dp[v][2]);
        if (g>dp[u][1])
        {
            ch[u][1]=iINF;
            if (g==k0+dp[v][0]+cost)
                ch[u][1]=c0+ch[v][0];
            if (g==k1+dp[v][0] && c1+ch[v][0]<ch[u][1])
                ch[u][1]=c1+ch[v][0];
            if (g==k1+dp[v][1]-mid && c1+ch[v][1]+1<ch[u][1])
                ch[u][1]=c1+ch[v][1]+1;
            if (g==k0+dp[v][1]+cost && c0+ch[v][1]<ch[u][1])
                ch[u][1]=c0+ch[v][1];
            if (g==k1+dp[v][2] && c1+ch[v][2]<ch[u][1])
                ch[u][1]=c1+ch[v][2];
            dp[u][1]=g;
        }
        g=Fimx(k2+dp[v][0],k1+dp[v][0]-mid+cost,k2+dp[v][1]-mid,k1+dp[v][1]-mid+cost,k2+dp[v][2]);
        if (g>dp[u][2])
        {
            ch[u][2]=iINF;
            if (g==k2+dp[v][0])
                ch[u][2]=c2+ch[v][0];
            if (g==k1+dp[v][0]-mid+cost && c1+ch[v][0]+1<ch[u][2])
                ch[u][2]=c1+ch[v][0]+1;
            if (g==k2+dp[v][1]-mid && c2+ch[v][1]+1<ch[u][2])
                ch[u][2]=c2+ch[v][1]+1;
            if (g==k1+dp[v][1]-mid+cost && c1+ch[v][1]+1<ch[u][2])
                ch[u][2]=c1+ch[v][1]+1;
            if (g==k2+dp[v][2] && c2+ch[v][2]<ch[u][2])
                ch[u][2]=c2+ch[v][2];
            dp[u][2]=g;      
        }
    }
}
ll ans2;
int ans1;
void check(ll mid)
{
    for (int i=1;i<=n;i++)
    {
        dp[i][0]=0;
        ch[i][0]=0;
        dp[i][1]=-INF;
        ch[i][1]=-10000000;
        dp[i][2]=-mid;
        ch[i][2]=1;
    }
    dfs(1,0,mid);
    ll g=Tmx(dp[1][0],dp[1][1]-mid,dp[1][2]);
    ans1=iINF;
    if (g==dp[1][0])
    {
        ans1=ch[1][0];
        ans2=dp[1][0];
    }
    if (g==dp[1][1]-mid && ch[1][1]+1<ans1)
    {
        ans1=ch[1][1]+1;
        ans2=dp[1][1]-mid;
    }
    if (g==dp[1][2] && ch[1][2]<ans1)
    {
        ans1=ch[1][2];
        ans2=dp[1][2];
    }
}
int main()
{
    scanf("%d%d",&n,&k),k++;
    ll l=0,r=0;
    for (int i=1;i<n;i++)
    {
        scanf("%d%d%d",&x,&y,&z);
        add(x,y,z);
        add(y,x,z);
        r+=abs(z);
    }
    r++;
    l=-r;
    ll ansl=-INF;
    while (l<=r)
    {
        ll mid=(l+r) >> 1;
        check(mid);
        if (ans1==k)
        {
            ansl=mid;
            break;
        }
        if (ans1>k)
            l=mid+1; else
            ansl=mid,r=mid-1;
    }
    check(ansl);
    printf("%lld\n",ans2+ansl*k);
    return 0;
}

推荐阅读