2019ICPC沈阳网络赛 – D. Fish eating fruit

  • Post author:
  • Post category:其他




题意:

给定一棵有 $n$ 个结点的树,树边带权,求任意两点间距离中模 $3$ 分别为 $0,~1,~2$ 的路径长度和。($n\leq 1e4,~\sum n \leq 1e5$)







链接:


https://nanti.jisuanke.com/t/41403



解题思路:

点分治 $or$ 树形dp。



参考代码:

点分治:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>

using namespace std;
// Shorthand for vector::push_back.
#define pb push_back
// Container size converted to int so it can be compared with int loop indices.
#define sz(a) ((int)a.size())
typedef long long ll;
// Adjacency entry: first = edge weight, second = neighbouring node.
typedef pair<int, int> pii;
// Capacity of the per-node arrays (node ids are shifted to start at 1 on input).
const int maxn = 1e4 + 5;
// Modulus applied to every accumulated length sum.
const int mod = 1e9 + 7;
// Starting value for rmn before a root search.
const int inf = 0x3f3f3f3f;

// G[u] lists the weighted edges leaving node u.
vector<pii> G[maxn];
// vis[u] != 0 while u is a removed decomposition root; siz[u] is the subtree
// size written by the most recent getRt traversal that reached u.
int vis[maxn], siz[maxn];
// ans[r]: accumulated answer for path lengths congruent to r modulo 3.
// sum[r] / num[r]: scratch filled by dfs - total length (mod `mod`) and count
// of nodes whose length from the current start is congruent to r modulo 3.
ll ans[3], sum[3], num[3];
// n: node count of the current test case; tn: component size handed to getRt;
// rt: best root found so far by getRt; rmn: largest-part size achieved by rt.
int n, tn, rt, rmn;

// Searches the unremoved component containing `u` for its centroid.
//
// u : node currently being visited
// f : node we arrived from (0 for the start of the search, since ids start at 1)
//
// Reads the global `tn` as the component size and `vis` as the removed-node
// mask. Writes `siz[x]` (subtree size with respect to this traversal) for every
// node reached, and leaves in `rt` / `rmn` the node whose largest remaining
// part after removal is smallest, together with that part's size. The caller
// must set `rmn` to a large value before the first call.
void getRt(int u, int f){
    int heaviest = 0;   // largest part left if u were removed
    siz[u] = 1;
    for(const auto& edge : G[u]){
        const int v = edge.second;   // the weight (edge.first) is irrelevant here
        if(vis[v] || v == f) continue;
        getRt(v, u);
        siz[u] += siz[v];
        heaviest = std::max(heaviest, siz[v]);
    }
    // Everything that is not in u's subtree hangs off u's parent side.
    heaviest = std::max(heaviest, tn - siz[u]);
    if(heaviest < rmn){
        rt = u;
        rmn = heaviest;
    }
}

// Visits every unremoved node reachable from `u` without going back through
// `f`, carrying the accumulated path length `len`.
//
// For each node reached, its length is bucketed by residue modulo 3:
// `num[r]` counts the nodes and `sum[r]` holds the total length modulo `mod`.
void dfs(int u, int f, ll len){
    const int r = static_cast<int>(len % 3);
    num[r] += 1;
    sum[r] = (sum[r] + len) % mod;
    for(const auto& edge : G[u]){
        const int to = edge.second;
        if(to == f || vis[to]) continue;
        dfs(to, u, len + edge.first);
    }
}

// Adds (opt = 1) or subtracts (opt = -1) to `ans` the combined length of every
// ordered pair of nodes collected by a dfs that starts at `rt` with initial
// length `len`.
//
// For residues a and b, the pairs (x in bucket a, y in bucket b) contribute
// sum[a]*num[b] + sum[b]*num[a] to the bucket (a + b) % 3. The result is kept
// in [0, mod) even when opt is negative.
void cal(int rt, int len, int opt){
    for(int r = 0; r < 3; ++r){
        sum[r] = 0;
        num[r] = 0;
    }
    dfs(rt, 0, len);
    for(int a = 0; a < 3; ++a){
        for(int b = 0; b < 3; ++b){
            const long long pairTotal = sum[a] * num[b] % mod + sum[b] * num[a] % mod;
            long long& slot = ans[(a + b) % 3];
            // "+ mod" lifts a negative remainder back into range.
            slot = ((slot + opt * pairTotal) % mod + mod) % mod;
        }
    }
}

// Centroid-decomposition step rooted at `u`.
//
// Precondition: `tn` holds the size of u's current component and `siz[]` comes
// from the getRt traversal over that component (this is how main and the
// recursive call below invoke it).
//
// For every neighbouring sub-component the pairs that lie entirely inside it
// are subtracted first (cal with opt = -1, lengths measured through u), then
// the sub-component is decomposed recursively. Finally all pairs of u's
// component, measured through u, are added (cal with opt = 1). `vis[u]` is
// cleared before that last step so the dfs can start at u, and so the array is
// clean again for the next test case.
void dfz(int u){
    // Remember the component size now: tn is overwritten for each child below.
    const int total = tn;
    vis[u] = 1;
    for(const auto& edge : G[u]){
        const int v = edge.second, w = edge.first;
        if(vis[v]) continue;
        cal(v, w, -1);
        // siz[] is relative to the traversal that found u. If v was u's parent
        // in that traversal, siz[v] still counts u's side, so the real size of
        // v's component is whatever is left after taking out u's subtree.
        tn = siz[v] > siz[u] ? total - siz[u] : siz[v];
        rmn = inf;
        getRt(v, 0);
        dfz(rt);
    }
    vis[u] = 0;
    cal(u, 0, 1);
}

int main(){

    while(scanf("%d", &n) != EOF){

        for(int i = 1; i <= n; ++i) G[i].clear();
        for(int i = 1; i < n; ++i){

            int u, v, w; scanf("%d%d%d", &u, &v, &w); ++u, ++v;
            G[u].pb({w, v}), G[v].pb({w, u});
        }
        for(int i = 0; i < 3; ++i) ans[i] = 0;
        tn = n, rmn = inf, getRt(1, 0);
        dfz(1);
        for(int i = 0; i < 3; ++i) printf("%lld%c", ans[i], i == 2 ? '\n' : ' ');
    }
    return 0;
}


树形dp:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
// Shorthand for vector::push_back.
#define pb push_back
// Container size converted to int so it can be compared with int loop indices.
#define sz(a) ((int)a.size())
typedef long long ll;
// Adjacency entry: first = edge weight, second = neighbouring node.
typedef pair<int, int> pii;
// Capacity of the per-node arrays (node ids are shifted to start at 1 on input).
const int maxn = 1e4 + 5;
// Modulus applied to every accumulated distance sum.
const int mod = 1e9 + 7;
// G[u] lists the weighted edges leaving node u.
vector<pii> G[maxn];
// Indexed [node][distance residue modulo 3]:
//   dp   - number of nodes inside the node's subtree (itself included)
//   ans  - sum of distances to those nodes, modulo `mod`
//   fp   - number of nodes outside the node's subtree
//   fans - sum of distances to those outside nodes, modulo `mod`
// dp/ans are filled by dfs1, fp/fans by dfs2.
ll dp[maxn][3], fp[maxn][3], ans[maxn][3], fans[maxn][3];
// Node count of the current test case.
int n;

void dfs1(int u, int f){

    dp[u][0] = 1, dp[u][1] = dp[u][2] = 0;
    ans[u][0] = ans[u][1] = ans[u][2] = 0;
    for(int i = 0; i < sz(G[u]); ++i){

        int v = G[u][i].second, w = G[u][i].first;
        if(v == f) continue;
        dfs1(v, u);
        for(int j = 0; j < 3; ++j){

            dp[u][(w + j) % 3] += dp[v][j];
            ans[u][(w + j) % 3] = (ans[u][(w + j) % 3] + ans[v][j] + dp[v][j] * w) % mod;
        }
    }
}

// Top-down pass: for every child `v` of `u` (parent `f`), derives fp[v] and
// fans[v] - the count and distance sum of nodes outside v's subtree - from the
// values already known for u. Requires dp/ans from dfs1 and fp[u]/fans[u] to be
// set; the starting node's rows are never written here, so they keep the zero
// value of the file-scope arrays.
void dfs2(int u, int f){
    for(const auto& edge : G[u]){
        const int v = edge.second, w = edge.first;
        if(v == f) continue;
        for(int r = 0; r < 3; ++r){
            // Nodes at residue r from u that live inside v's subtree are the
            // ones at residue (r - w) mod 3 from v.
            const int below = ((r - w) % 3 + 3) % 3;
            // Seen from v, every outside node is w further away than from u.
            const int atV = (w + r) % 3;

            fp[v][atV] = fp[u][r] + dp[u][r] - dp[v][below];

            long long total = fans[u][r] + ans[u][r];   // everything at residue r from u
            total -= ans[v][below];                     // drop v's own subtree ...
            total -= dp[v][below] * w;                  // ... including the edge u-v it paid
            total += fp[v][atV] * w;                    // each outside node gains the edge
            fans[v][atV] = (total % mod + mod) % mod;
        }
        dfs2(v, u);
    }
}

int main(){

    while(scanf("%d", &n) != EOF){

        for(int i = 1; i <= n; ++i) G[i].clear();
        for(int i = 1; i < n; ++i){

            int u, v, w; scanf("%d%d%d", &u, &v, &w); ++u, ++v;
            G[u].pb({w, v}), G[v].pb({w, u});
        }
        dfs1(1, 0), dfs2(1, 0);
        ll ret[3] = {};
        for(int i = 1; i <= n; ++i){

            for(int j = 0; j < 3; ++j) ret[j] = (ret[j] + ans[i][j] + fans[i][j]) % mod;
        }
        printf("%lld %lld %lld\n", ret[0], ret[1], ret[2]);
    }
    return 0;
}



版权声明:本文为weixin_44059127原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。