首页 > 技术文章 > Problem G. Depth-First Search

zgglj-com 2018-08-09 14:23 原文

PS:官方题解挺详细的,但自己动手写还是有点难度。在以 a[0](即给定序列的第一个元素)为根的树上做 DFS 时,需要把上一层的合法方案数向下传递,也就是 v = v * sum * fac[i](v 是上一层的合法方案数,fac[i] 是剩余 i 个孩子的排列数,sum 是剩余各孩子子树方案数的乘积)。吐槽一下,计数类问题是真的麻烦,要做到不重不漏。

#include<bits/stdc++.h>
// Contest shorthands.
// Every macro that expands to an expression is fully parenthesized, and
// function-like macros parenthesize their arguments, so each expansion acts as a
// single operand inside a larger expression (e.g. -INF, 1.0 / PI, 2 * lson).
#define ll long long
#define P pair<int, int>
#define pb push_back
#define mp make_pair
#define pp pop_back
#define lson (root << 1)
#define INF ((int)2e9 + 7)
#define rson (root << 1 | 1)
#define LINF ((unsigned long long int)1e18)
#define sc(x) scanf("%d", &(x))
#define pr(x) printf("%d\n", (x))
#define mem(arry, in) memset((arry), (in), sizeof(arry))
#define PI (acos(0.5) * 3)
#define EPS 0.00000001
using namespace std;

// Running maximum: overwrite x with y whenever y is strictly larger.
inline void upd(int& x, int y) {
    if (y > x) x = y;
}

const int N = 1000005;
const int mod = 1000000007;

vector<int> G[N];

// Modular exponentiation by repeated squaring: returns a^b reduced modulo `mod`
// (returns 1 when b == 0).
int powi(int a, int b) {
    long long result = 1;
    long long base = a;
    while (b != 0) {
        if (b & 1) result = result * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return static_cast<int>(result);
}

// Product a * b * c modulo `mod`. The intermediate is reduced after the first
// multiplication so each 64-bit product stays in range.
inline int mul(int a, int b, int c) {
    const long long ab = static_cast<long long>(a) * b % mod;
    return static_cast<int>(ab * c % mod);
}

// Empty the adjacency lists of nodes 1..n so a new test case can be read in.
inline void init(int n) {
    for (int node = n; node >= 1; --node) G[node].clear();
}

// Fenwick (binary indexed) tree over positions 1..n holding int counts.
// In this file bit[u] holds a 1 for every child of u that has not been visited yet.
struct Tree {
    int n;               // number of positions
    std::vector<int> T;  // 1-based storage; index 0 is padding

    // Resize to _n positions and zero every entry.
    void init(int _n) {
        n = _n;
        T.assign(n + 1, 0);
    }

    // Add x at position pos; pos must be >= 1 or the index never advances.
    void add(int pos, int x) {
        while (pos <= n) {
            T[pos] += x;
            pos += pos & -pos;
        }
    }

    // Prefix sum over positions 1..pos, with pos clamped to n.
    int sum(int pos) {
        int total = 0;
        int i = pos > n ? n : pos;
        while (i != 0) {
            total += T[i];
            i -= i & -i;
        }
        return total;
    }
} bit[N];  // one tree per node label

// Per-test-case scalars.
int T;    // number of test cases left to process (counted down in main)
int n;    // number of nodes in the current tree
int ans;  // running answer, kept reduced modulo mod
int id;   // index into a[] of the node solve() is currently on
int d;    // product over all nodes of (deg - 1)!, built in main

// Tables indexed directly by position or node label.
int a[N];     // the given sequence a[0..n-1]; main stores a sentinel 0 at a[n]
int fac[N];   // fac[i] = i! modulo mod, filled by Inite
int inv[N];   // inverse factorials, filled by Inite (not read by the code shown here)
int f[N];     // f[u] = number of DFS orders of u's rooted subtree, filled by DFS
int invf[N];  // invf[u] = modular inverse of f[u], filled by DFS

// Precompute factorials fac[0..mx] and inverse factorials inv[0..mx] modulo `mod`.
// inv[mx] comes from Fermat's little theorem (mod = 1e9+7 is prime). Every smaller
// entry then follows from inv[i] = inv[i + 1] * (i + 1).
void Inite() {
    const int mx = 1000000;  // must stay below N, the capacity of fac/inv

    fac[0] = 1;
    for (int i = 1; i <= mx; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;

    inv[mx] = powi(fac[mx], mod - 2);
    // Run all the way down to i == 0 so that inv[0] (= 1/0! = 1) is filled in too.
    for (int i = mx - 1; i >= 0; --i) inv[i] = 1ll * inv[i + 1] * (i + 1) % mod;
}

// Follows the input sequence a[] down the rooted tree, starting at node u.
// On entry u == a[id + 1]; id is pre-incremented below, so u then sits at a[id].
// For each position where some DFS order could diverge from a[] by choosing a
// smaller child label, adds the number of such orders to the global `ans`.
// v is the number of ways to finish everything outside u's subtree once that
// subtree is complete. The root call passes 1.
// Returns 1 when a[] cannot be followed any further because the next element is
// not an unvisited child of u; this aborts every caller. Returns 0 once all of
// u's children are consumed, so the caller continues with its next child.
int solve(int u, int v) {
    id++;
    // sum = product of f[c] over u's children c that are still unvisited.
    int sum = 1;
    for (auto tp : G[u]) sum = 1ll * sum * f[tp] % mod;
    // i = number of children still unvisited after the next one is chosen.
    // ~i is nonzero until i reaches -1, which ends the loop.
    for (int i = G[u].size() - 1; ~i; --i) {
        // Index of the first child (G[u] is sorted) whose label is >= the next
        // sequence element; a[id + 1] may be the sentinel a[n] = 0 set in main.
        int nxt = lower_bound(G[u].begin(), G[u].end(), a[id + 1]) - G[u].begin();
        // Number of unvisited children with a label smaller than a[id + 1].
        int cnt = bit[u].sum(nxt);
        // Visiting any of those children next yields a smaller sequence. Each
        // choice completes in sum * i! * v ways: the chosen subtree's f is part of
        // sum, and the other i children can come in any order.
        ans = (0ll + ans + 1ll * mul(sum, fac[i], v) * cnt % mod) % mod;
        // a[id + 1] is not a child of u, so no DFS order can keep matching a[].
        if (nxt == G[u].size() || G[u][nxt] != a[id + 1]) return 1;
        // Mark the matched child as visited and drop its factor from sum.
        bit[u].add(nxt + 1, -1);
        sum = 1ll * sum * invf[G[u][nxt]] % mod;
        // Descend. Once that child's subtree is done, the rest can be finished in
        // v * i! * sum ways (sum now excludes the child).
        if (solve(G[u][nxt], mul(v, fac[i], sum))) return 1;
    }
    return 0;
}

// Roots the subtree of u, where p is u's parent (the root is called with p = -1).
// Afterwards:
//  - G[u] holds only u's children, in ascending label order;
//  - bit[u] has a 1 at every child position;
//  - f[u] = (#children)! * product of f[child], the number of DFS orders of
//    u's subtree, and invf[u] is its inverse modulo mod.
void DFS(int u, int p) {
    std::vector<int>& kids = G[u];
    if (p > 0) kids.erase(std::find(kids.begin(), kids.end(), p));

    const int k = static_cast<int>(kids.size());
    bit[u].init(k);
    for (int pos = 1; pos <= k; ++pos) bit[u].add(pos, 1);

    std::sort(kids.begin(), kids.end());
    long long ways = fac[k];
    for (const int c : kids) {
        DFS(c, u);
        ways = ways * f[c] % mod;
    }
    f[u] = static_cast<int>(ways);
    invf[u] = powi(f[u], mod - 2);
}

// Reads T test cases. Each case gives n, a sequence a[0..n-1] and n - 1 tree edges.
// Prints, modulo mod, the number of DFS orders (over every choice of root and every
// child order) that are lexicographically smaller than a.
int main()
{
    Inite();
    sc(T);
    while(T--) {
        sc(n);
        init(n);
        for (int i = 0; i < n; ++i) sc(a[i]);
        for (int i = 1; i < n; ++i) {
            int u, v;
            sc(u), sc(v);
            G[u].pb(v);
            G[v].pb(u);
        }
        // a[n] = 0 guards solve()'s a[id + 1] lookahead; id = -1 because solve()
        // pre-increments it before use.
        a[n] = 0, d = 1, ans = 0, id = -1;
        // d = product over all nodes of (deg - 1)!, the child orderings of every
        // node when it is not the root. A degree-0 node (only possible when n == 1)
        // is skipped, because G[i].size() - 1 would wrap around to SIZE_MAX and
        // index fac out of bounds.
        for (int i = 1; i <= n; ++i)
            if (!G[i].empty()) d = 1ll * d * fac[G[i].size() - 1] % mod;
        // Every DFS order rooted at a label smaller than a[0] is smaller than a.
        // Rooted at i there are deg(i)! * prod_{v != i} (deg(v) - 1)! = d * deg(i).
        for (int i = 1; i < a[0]; ++i) ans = (0ll + ans + 1ll * d * G[i].size() % mod) % mod;
        DFS(a[0], -1);
        solve(a[0], 1);
        printf("%d\n", ans);
    }
    return 0;
}
View Code

 

推荐阅读