题意:
给定一棵有 $n$ 个结点的树,树边带权,求任意两点间距离中模 $3$ 分别为 $0,~1,~2$ 的路径长度和。($n \leq 10^4,~\sum n \leq 10^5$)
链接:
https://nanti.jisuanke.com/t/41403
解题思路:
点分治 或 树形dp。
参考代码:
点分治:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
// Shorthands and shared state for the centroid-decomposition solution.
#define pb push_back
#define sz(a) ((int)a.size())
using ll = long long;
using pii = std::pair<int, int>;   // adjacency entry: {edge weight, neighbour}

constexpr int maxn = 10005;        // nodes are stored 1-indexed (main shifts input ids by one)
constexpr int mod = 1000000007;
constexpr int inf = 0x3f3f3f3f;

std::vector<pii> G[maxn];          // G[u] holds {w, v} for every edge u-v of weight w
int vis[maxn];                     // 1 while the node is a centroid on the current dfz recursion path
int siz[maxn];                     // subtree sizes left behind by the latest getRt that visited the node
ll ans[3];                         // ans[r]: running answer for paths whose length is ≡ r (mod 3)
ll sum[3], num[3];                 // per-residue distance sums / node counts filled by dfs
int n;                             // number of nodes in the current test case
int tn, rt, rmn;                   // getRt inputs/outputs: component size, best centroid, its largest part
void getRt(int u, int f){
int mx = 0; siz[u] = 1;
for(int i = 0; i < sz(G[u]); ++i){
int v = G[u][i].second, w = G[u][i].first;
if(vis[v] || v == f) continue;
getRt(v, u);
siz[u] += siz[v];
mx = max(mx, siz[v]);
}
mx = max(mx, tn - siz[u]);
if(mx < rmn) rt = u, rmn = mx;
}
void dfs(int u, int f, ll len){
++num[len % 3];
sum[len % 3] = (sum[len % 3] + len) % mod;
for(int i = 0; i < sz(G[u]); ++i){
int v = G[u][i].second, w = G[u][i].first;
if(vis[v] || v == f) continue;
dfs(v, u, len + w);
}
}
void cal(int rt, int len, int opt){
for(int i = 0; i < 3; ++i) num[i] = sum[i] = 0;
dfs(rt, 0, len);
for(int i = 0; i < 3; ++i){
for(int j = 0; j < 3; ++j){
ans[(i + j) % 3] = ((ans[(i + j) % 3] + opt * (sum[i] * num[j] % mod + sum[j] * num[i] % mod)) % mod + mod) % mod;
}
}
}
// Centroid-decomposition step for the component whose chosen centre is u.
// For each unvisited neighbour v, pairs lying inside v's subtree are first
// removed with cal(v, w, -1), and that subtree is decomposed recursively.
// Finally cal(u, 0, 1) adds dist(a,u) + dist(u,b) for every ordered pair (a, b)
// of u's component. The subtractions above cancel the pairs whose real path
// does not pass through u.
void dfz(int u){
// Block u so that dfs/getRt started from a neighbour stay inside that neighbour's subtree.
vis[u] = 1;
for(int i = 0; i < sz(G[u]); ++i){
int v = G[u][i].second, w = G[u][i].first;
if(vis[v]) continue;
// Distances are offset by w, i.e. measured from u, matching the later cal(u, 0, 1).
cal(v, w, -1);
// siz[v] is whatever the last getRt that visited v left behind. It is v's exact
// subtree size only if v was a descendant of u in that search.
// NOTE(review): if v was u's parent there, tn is an overestimate; this only
// affects which centroid is picked, not the answer.
tn = siz[v], rmn = inf, getRt(v, 0);
dfz(rt);
}
// Each recursive dfz also clears its own vis flag on return. After this reset,
// the dfs inside cal(u, 0, 1) reaches u's whole component, bounded only by
// centroids still on the recursion stack (which keep vis == 1).
vis[u] = 0;
cal(u, 0, 1);
}
int main(){
while(scanf("%d", &n) != EOF){
for(int i = 1; i <= n; ++i) G[i].clear();
for(int i = 1; i < n; ++i){
int u, v, w; scanf("%d%d%d", &u, &v, &w); ++u, ++v;
G[u].pb({w, v}), G[v].pb({w, u});
}
for(int i = 0; i < 3; ++i) ans[i] = 0;
tn = n, rmn = inf, getRt(1, 0);
dfz(1);
for(int i = 0; i < 3; ++i) printf("%lld%c", ans[i], i == 2 ? '\n' : ' ');
}
return 0;
}
树形dp:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
// Shorthands and DP tables for the rerooting tree-DP solution.
#define pb push_back
#define sz(a) ((int)a.size())
using ll = long long;
using pii = std::pair<int, int>;   // adjacency entry: {edge weight, neighbour}

constexpr int maxn = 10005;        // nodes are stored 1-indexed (main shifts input ids by one)
constexpr int mod = 1000000007;

std::vector<pii> G[maxn];          // G[u] holds {w, v} for every edge u-v of weight w
// dp[u][r]  : nodes in u's subtree (u included) whose distance from u is ≡ r (mod 3)
// ans[u][r] : sum of those distances, reduced mod `mod`
// fp[u][r], fans[u][r] : the same two quantities for nodes outside u's subtree
ll dp[maxn][3], fp[maxn][3], ans[maxn][3], fans[maxn][3];
int n;                             // number of nodes in the current test case
void dfs1(int u, int f){
dp[u][0] = 1, dp[u][1] = dp[u][2] = 0;
ans[u][0] = ans[u][1] = ans[u][2] = 0;
for(int i = 0; i < sz(G[u]); ++i){
int v = G[u][i].second, w = G[u][i].first;
if(v == f) continue;
dfs1(v, u);
for(int j = 0; j < 3; ++j){
dp[u][(w + j) % 3] += dp[v][j];
ans[u][(w + j) % 3] = (ans[u][(w + j) % 3] + ans[v][j] + dp[v][j] * w) % mod;
}
}
}
void dfs2(int u, int f){
for(int i = 0; i < sz(G[u]); ++i){
int v = G[u][i].second, w = G[u][i].first;
if(v == f) continue;
for(int j = 0; j < 3; ++j){
fp[v][(w + j) % 3] = fp[u][j] + dp[u][j] - dp[v][((j - w) % 3 + 3) % 3];
fans[v][(w + j) % 3] = ((fans[u][j] + ans[u][j] - ans[v][((j - w) % 3 + 3) % 3] - dp[v][((j - w) % 3 + 3) % 3] * w + fp[v][(w + j) % 3] * w) % mod + mod) % mod;
}
dfs2(v, u);
}
}
// Tree-DP driver. Reads test cases until EOF.
// Each case is n followed by n-1 edges "u v w" with 0-indexed endpoints, which
// are shifted to 1-indexed. Prints the totals for residues 0, 1 and 2.
int main(){
    while (scanf("%d", &n) != EOF) {
        for (int i = 1; i <= n; ++i) G[i].clear();
        for (int e = 1; e < n; ++e) {
            int a, b, w;
            scanf("%d%d%d", &a, &b, &w);
            ++a;
            ++b;
            G[a].pb({w, b});
            G[b].pb({w, a});
        }
        // Node 1 is the root. dfs2 only writes fp/fans for children, so fp[1] and
        // fans[1] keep their zero initial value: nothing lies outside the root's subtree.
        dfs1(1, 0);
        dfs2(1, 0);
        ll ret[3] = {0, 0, 0};
        for (int i = 1; i <= n; ++i)
            for (int r = 0; r < 3; ++r)
                ret[r] = (ret[r] + ans[i][r] + fans[i][r]) % mod;
        printf("%lld %lld %lld\n", ret[0], ret[1], ret[2]);
    }
    return 0;
}