首页 > 技术文章 > P3478 [POI2008]STA-Station

luckyblock 2020-10-18 22:08 原文

知识点:换根 DP

原题面:Luogu


换 根 D P 入 门


题意简述

给定一 \(n\) 个结点的树,求出一个节点,使得以该节点为根时,所有节点的深度和最大。
一个节点的深度定义为该节点到根的简单路径上边的数量。
\(1\le n\le 10^6\)
时限 2s,内存 128M,SPJ。


分析题意

需要同时维护子树和祖先的信息,考虑换根 DP。
钦定 1 为根,第一次 dfs 处理子树信息,第二次 dfs 处理祖先信息进行换根。

先考虑子树信息,预处理出所有子树的 \(size\)
\(f_u\) 表示 \(u\) 的子树中所有点到 \(u\) 的距离之和,则显然有:

\[f_{u} = \sum_{fa_v = u}{\left\{f_{v} + size_v\right\}} \]


对于祖先的信息,考虑在第二次 dfs 中维护。
若当前 dfs 到节点 \(u\),设 \(val\) 表示 \(u\) 上面各点到 \(u\) 的距离之和。
考虑向下深入,根由 \(u\) 变为 \(v\) 时,对 \(val\) 的影响。

先考虑哪些点在 \(v\) 上方,包括 \(u\) 上方的点,及 \(v\) 的兄弟。
对于 \(u\) 上方的点,其到 \(u\) 的距离和为 \(val\)
对于 \(v\) 的兄弟们,其到 \(u\) 的距离和为 \(f_{u} - (f_{v} + size+v)\)(除去 \(v\) 的影响)。

对于所有上述节点,其到 \(v\) 的距离等于其到 \(u\) 的距离 \(+1\),则新的 \(val\) 还应 \(+ n-size_v\)


在第二次 dfs 换根时,进行答案的判定。
答案即为 \(f_u + val\) 最大的节点之一。


爆零小技巧


代码实现

自 YY 写法。

//知识点:换根 DP 
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kMaxn = 1e6 + 10;
const int kMaxm = kMaxn << 1;
//=============================================================
int n, ans, e_num, head[kMaxn], v[kMaxm], ne[kMaxm];
LL max_sum, f[kMaxn], size[kMaxn];
//=============================================================
inline int read() {
  int f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir_, int sec_) {
  if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
  if (sec_ < fir_) fir_ = sec_;
}
void AddEdge(int u_, int v_) {
  v[++ e_num] = v_;
  ne[e_num] = head[u_];
  head[u_] = e_num;
}
void Dfs1(int u_, int fa_) {
  size[u_] = 1ll;
  for (int i = head[u_]; i; i = ne[i]) {
    int v_ = v[i];
    if (v_ == fa_) continue ;
    Dfs1(v_, u_);
    size[u_] += size[v_];
    f[u_] += f[v_] + size[v_];
  }
}
void Dfs2(int u_, int fa_, LL val_) {
  if (f[u_] + val_ > max_sum) {
    max_sum = f[fa_] + val_;
    ans = u_;
  }
  for (int i = head[u_]; i; i = ne[i]) {
    int v_ = v[i];
    if (v_ == fa_) continue ;
    LL new_val = val_ + f[u_] - f[v_] - size[v_] + (n - size[v_]);
    Dfs2(v_, u_, new_val);
  }
}
//=============================================================
int main() {
  n = read();
  for (int i = 1; i < n; ++ i) {
    int u_ = read(), v_ = read();
    AddEdge(u_, v_), AddEdge(v_, u_);
  }
  Dfs1(1, 0);
  Dfs2(1, 0, 0);
  printf("%d\n", ans);
  return 0;
}

另一种循环利用 \(f\) 的写法。

//知识点:换根 DP 
/*
By:Luckyblock
设 size[i] 表示,i 子树的大小。   
dep[i] 表示,i 的深度。   
f[i] 表示,以 i 为根时,各点与其距离之和。   
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kMaxn = 1e6 + 10;
const int kMaxm = kMaxn << 1;
//=============================================================
int n, ans, e_num, head[kMaxn], v[kMaxm], ne[kMaxm];
LL max_sum, size[kMaxn], f[kMaxn], dep[kMaxn];
//=============================================================
inline int read() {
  int f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir_, int sec_) {
  if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
  if (sec_ < fir_) fir_ = sec_;
}
void AddEdge(int u_, int v_) {
  v[++ e_num] = v_;
  ne[e_num] = head[u_];
  head[u_] = e_num;
}
void Dfs1(int u_, int fa_) {
  size[u_] = 1ll;
  dep[u_] = dep[fa_] + 1ll;
  for (int i = head[u_]; i; i = ne[i]) {
    int v_ = v[i];
    if (v_ == fa_) continue ;
    Dfs1(v_, u_);
    size[u_] += size[v_];
  }
}
void Dfs2(int u_, int fa_) {
  if (f[u_] > max_sum) {
    max_sum = f[u_];
    ans = u_;
  }
  for (int i = head[u_]; i; i = ne[i]) {
    int v_ = v[i];
    if (v_ == fa_) continue ;
    f[v_] = f[u_] - size[v_] + (n - size[v_]);
    Dfs2(v_, u_);
  }
}
//=============================================================
int main() {
  n = read();
  for (int i = 1; i < n; ++ i) {
    int u_ = read(), v_ = read();
    AddEdge(u_, v_), AddEdge(v_, u_);
  }
  Dfs1(1, 0); 
  for (int i = 1; i <= n; ++ i) {
    f[1] = dep[i];
  }
  Dfs2(1, 0);
  printf("%d\n", ans);
  return 0;
}

推荐阅读