虚树入门学习

  • Post author:
  • Post category:其他




题目


link



题意

给定一个点集,求切断若干条边、使点集中的每个点都与根节点不连通所需的最小边权和。

虚树就是只保留本次询问的关键点以及它们之间的 LCA、把其余结点压缩掉的一棵简化树;它保留了全部有效信息,从而降低复杂度。

本题中,如果我们要处理 7 号和 10 号节点,显然 2 号点及其子树都是多余的。

请添加图片描述

我们可以得到简化的 树

请添加图片描述

我们对这棵树进行树形 dp,显然很容易。

那么我们如何构建虚树呢?我们可以通过 dfs 序和一个栈来维护。

具体细节见

如何构建虚树

直接在栈上dp

因为出栈的顺序是自底向上的(子节点总是先于它在虚树上的父节点出栈),

所以也可以在出栈连边的同时直接完成 dp 转移。

显然,栈上 dp 不需要显式建出虚树,也就不需要清空邻接表,是更简洁的写法。

#include <bits/stdc++.h>
using namespace  std;
#define  int long long
//typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
#define fi first
#define se second
#define pb  push_back
#define inf 1e18
#define endl '\n'
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define de_bug(x) cerr << #x << "=" << x << endl
#define all(a) a.begin(),a.end()
#define IOS   std::ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define  fer(i,a,b)  for(int i=a;i<=b;i++)
#define  der(i,a,b)  for(int i=a;i>=b;i--)
// Modulus constant; not referenced by any function in this listing.
const int mod = 1e9 + 7;
// Capacity of every per-node / per-edge array below.
const int N = 1e6 + 10;
// n: number of tree nodes read in solve(); k is not referenced in this listing.
int n, k;
// Adjacency list: h[u] = first edge of u, e[i] = target of edge i, w[i] = its weight,
//                 ne[i] = next edge of the same node, cnt = number of edges stored.
// Tree data:      f = parent, d = depth, siz = subtree size, son = heavy child,
//                 top = head of the heavy chain, dfn/tot = preorder index and its counter.
// mi[v]:          smallest edge weight on the path from node 1 to v (mi[1] is set to 2^60).
// Per query:      sum = DP value per node, stk/t = stack and its height used by build(),
//                 q = number of queries, a = nodes of the current query.
// idx is not referenced in this listing.
int f[N],tot, son[N],h[N],e[N],ne[N],idx,w[N],d[N],cnt,siz[N],top[N],mi[N],dfn[N],sum[N] ,t,stk[N],q ,a[N];
// Stores the directed edge a -> b with weight c and links it in front of
// the edges already recorded for a.
void add(int a, int b, int c) {
	const int slot = cnt++;
	e[slot] = b;
	w[slot] = c;
	ne[slot] = h[a];
	h[a] = slot;
}
// First traversal from u (parent fa). Fills, for every node of the subtree:
//   d   - depth (the node entered without a real parent gets depth 1),
//   f   - parent,
//   dfn - preorder index,
//   siz - subtree size,
//   son - child with the largest subtree,
//   mi  - smallest edge weight on the path from the start node (mi[u] must
//         already be set by the caller for the start node).
void dfs(int u, int fa) {
	// solve() enters the root with a negative sentinel parent; using it as an
	// index would read in front of d[], so a missing parent counts as depth 0.
	d[u] = (fa < 0 ? 0 : d[fa]) + 1;
	f[u] = fa;
	dfn[u] = ++tot;
	siz[u] = 1;
	int ma = -1; // largest child subtree size seen so far
	for(int i = h[u]; ~i; i = ne[i]) {
		int v = e[i];
		if(v == fa)continue; // skip the edge back to the parent
		mi[v] = mi[u] < w[i] ? mi[u] : w[i];
		dfs(v, u);
		siz[u] += siz[v];
		if(ma < siz[v]) {
			son[u] = v, ma = siz[v];
		}
	}
}
// Second traversal: assigns top[u] = ff, continues the same chain into the
// heavy child and opens a new chain (headed by the child itself) for every
// other child. son[u] == 0 is treated as "u has no child".
void dfs2(int u, int ff) {
	top[u] = ff;
	const int heavy = son[u];
	if (heavy) {
		dfs2(heavy, ff);
		for (int i = h[u]; i != -1; i = ne[i]) {
			const int v = e[i];
			if (v == f[u] || v == heavy) continue;
			dfs2(v, v);
		}
	}
}
// Lowest common ancestor of x and y by chain jumping: while the two nodes
// sit on different chains, the one whose chain head is deeper (x on a tie)
// jumps to the parent of that head.
int lca(int x, int y) {
	for (;;) {
		if (top[x] == top[y]) break;
		if (d[top[x]] >= d[top[y]]) {
			x = f[top[x]];
		} else {
			y = f[top[y]];
		}
	}
	// Same chain now: the shallower node is the ancestor.
	if (d[x] < d[y]) return x;
	return y;
}
// DP transition for the virtual-tree edge x -> y: x pays the smaller of
// sum[y] and mi[y].
void  sol(int x, int y) {
	const int cost = sum[y] < mi[y] ? sum[y] : mi[y];
	sum[x] += cost;
}
// Pushes query node x onto the stack stk[1..t] (solve() puts node 1 at
// stk[1] and feeds nodes in increasing dfn order). Before pushing, every
// stack entry that is not an ancestor of x is popped, and each popped
// (parent, child) pair is handed to sol().
void  build(int x) {
	if (t != 1) {
		const int anc = lca(stk[t], x);
		if (anc != stk[t]) {
			// Pop while the entry below the top is still inside anc's subtree.
			while (t > 1 && dfn[stk[t - 1]] >= dfn[anc]) {
				sol(stk[t - 1], stk[t]);
				--t;
			}
			// anc itself is not on the stack yet: it replaces the top entry
			// and starts from an empty DP value.
			if (anc != stk[t]) {
				sum[anc] = 0;
				sol(anc, stk[t]);
				stk[t] = anc;
			}
		}
	}
	sum[x] = mi[x];
	stk[++t] = x;
}
// Sort predicate: true when x has the smaller preorder index.
bool cmp(int x, int y) {
	return dfn[y] > dfn[x];
}
void solve() {
	cin >> n;
	memset(h,-1,sizeof(h));
	fer(i, 1, n - 1) {
		int a, b, c;
		cin >> a >> b >> c;
		add(a, b, c);
		add(b, a, c);
	}
	mi[1] = 1ll << 60;
	dfs(1, -1);
	dfs2(1, 1);
	cin >> q;
	while(q--) {
		sum[1] = 0;
		stk[t = 1] = 1;
		int x;
		cin >> x;
		fer(i, 1, x)cin >> a[i];
		sort(a + 1, a + x + 1, cmp);
		fer(i, 1, x) build(a[i]);
		while(t) sol(stk[t - 1], stk[t]),t--;
		cout << sum[1] << endl;
	}
}
// Entry point: sets up fast I/O and runs solve() for a single test case
// (the count is fixed to 1; it is not read from the input).
signed main() {
	IOS;
	for (int remaining = 1; remaining > 0; --remaining) {
		solve();
	}
}


建好树后dp

每次询问只清空本次虚树中用到的点(这里在 dfs3 回溯时顺便清空 g[u]),

否则每次询问都要清空整个邻接表,复杂度会退化为 O(nq)。

#include <bits/stdc++.h>
using namespace  std;
#define  int long long
//typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
#define fi first
#define se second
#define pb  push_back
#define inf 1e18
#define endl '\n'
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define de_bug(x) cerr << #x << "=" << x << endl
#define all(a) a.begin(),a.end()
#define IOS   std::ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define  fer(i,a,b)  for(int i=a;i<=b;i++)
#define  der(i,a,b)  for(int i=a;i>=b;i--)
// Modulus constant; not referenced by any function in this listing.
const int mod = 1e9 + 7;
// Capacity of every per-node / per-edge array below.
const int N = 1e6 + 10;
// n: number of tree nodes read in solve(); k is not referenced in this listing.
int n, k;
// g[u]: children of u in the per-query virtual tree (filled by sol(), emptied by dfs3()).
vector<int>g[N];
// Adjacency list: h[u] = first edge of u, e[i] = target of edge i, w[i] = its weight,
//                 ne[i] = next edge of the same node, cnt = number of edges stored.
// Tree data:      f = parent, d = depth, siz = subtree size, son = heavy child,
//                 tot = preorder counter. idx is not referenced in this listing.
int f[N], tot, son[N], h[N], e[N], ne[N], idx, w[N], d[N], cnt, siz[N];
// top = head of the heavy chain, dfn = preorder index,
// mi[v] = smallest edge weight on the path from node 1 to v (mi[1] is set to 2^60),
// sum = DP value per node, stk/t = stack and its height used by build(),
// q = number of queries, a = nodes of the current query.
// ss appears only in commented-out code; id is reset per query in solve() and
// otherwise appears only in commented-out code.
int ss[N], top[N], mi[N], dfn[N], sum[N] , t, stk[N], q , a[N],id;

// Stores the directed edge a -> b with weight c and links it in front of
// the edges already recorded for a.
void add(int a, int b, int c) {
	const int slot = cnt++;
	e[slot] = b;
	w[slot] = c;
	ne[slot] = h[a];
	h[a] = slot;
}
// First traversal from u (parent fa). Fills, for every node of the subtree:
//   d   - depth (the node entered without a real parent gets depth 1),
//   f   - parent,
//   dfn - preorder index,
//   siz - subtree size,
//   son - child with the largest subtree,
//   mi  - smallest edge weight on the path from the start node (mi[u] must
//         already be set by the caller for the start node).
void dfs(int u, int fa) {
	// solve() enters the root with a negative sentinel parent; using it as an
	// index would read in front of d[], so a missing parent counts as depth 0.
	d[u] = (fa < 0 ? 0 : d[fa]) + 1;
	f[u] = fa;
	dfn[u] = ++tot;
	siz[u] = 1;
	int ma = -1; // largest child subtree size seen so far
	for(int i = h[u]; ~i; i = ne[i]) {
		int v = e[i];
		if(v == fa)continue; // skip the edge back to the parent
		mi[v] = mi[u] < w[i] ? mi[u] : w[i];
		dfs(v, u);
		siz[u] += siz[v];
		if(ma < siz[v]) {
			son[u] = v, ma = siz[v];
		}
	}
}
// Second traversal: assigns top[u] = ff, continues the same chain into the
// heavy child and opens a new chain (headed by the child itself) for every
// other child. son[u] == 0 is treated as "u has no child".
void dfs2(int u, int ff) {
	top[u] = ff;
	const int heavy = son[u];
	if (heavy) {
		dfs2(heavy, ff);
		for (int i = h[u]; i != -1; i = ne[i]) {
			const int v = e[i];
			if (v == f[u] || v == heavy) continue;
			dfs2(v, v);
		}
	}
}
// Lowest common ancestor of x and y by chain jumping: while the two nodes
// sit on different chains, the one whose chain head is deeper (x on a tie)
// jumps to the parent of that head.
int lca(int x, int y) {
	for (;;) {
		if (top[x] == top[y]) break;
		if (d[top[x]] >= d[top[y]]) {
			x = f[top[x]];
		} else {
			y = f[top[y]];
		}
	}
	// Same chain now: the shallower node is the ancestor.
	if (d[x] < d[y]) return x;
	return y;
}
// Records y as a child of x in the per-query virtual tree (directed edge
// only; the DP itself is done later by dfs3()).
void  sol(int x, int y) {
	g[x].emplace_back(y);
}

// Tree DP over the virtual tree stored in g: after a child v is finished,
// u pays the smaller of sum[v] and mi[v]. g[u] is emptied on the way out so
// the adjacency lists are clean for the next query.
void dfs3(int u){
	for (auto it = g[u].begin(); it != g[u].end(); ++it) {
		const auto v = *it;
		dfs3(v);
		const auto cost = sum[v] < mi[v] ? sum[v] : mi[v];
		sum[u] += cost;
	}
	g[u].clear();
}
// Pushes query node x onto the stack stk[1..t] (solve() puts node 1 at
// stk[1] and feeds nodes in increasing dfn order). Before pushing, every
// stack entry that is not an ancestor of x is popped, and each popped
// (parent, child) pair is handed to sol().
void  build(int x) {
	if (t != 1) {
		const int anc = lca(stk[t], x);
		if (anc != stk[t]) {
			// Pop while the entry below the top is still inside anc's subtree.
			while (t > 1 && dfn[stk[t - 1]] >= dfn[anc]) {
				sol(stk[t - 1], stk[t]);
				--t;
			}
			// anc itself is not on the stack yet: it replaces the top entry
			// and starts from an empty DP value.
			if (anc != stk[t]) {
				sum[anc] = 0;
				sol(anc, stk[t]);
				stk[t] = anc;
			}
		}
	}
	sum[x] = mi[x];
	stk[++t] = x;
}
// Sort predicate: true when x has the smaller preorder index.
bool cmp(int x, int y) {
	return dfn[y] > dfn[x];
}
void solve() {
	cin >> n;
	memset(h, -1, sizeof(h));
	fer(i, 1, n - 1) {
		int a, b, c;
		cin >> a >> b >> c;
		add(a, b, c);
		add(b, a, c);
	}
	mi[1] = 1ll << 60;
	dfs(1, -1);
	dfs2(1, 1);
	cin >> q;
	while(q--) {
		id=0;
		sum[1] = 0;
		stk[t = 1] = 1;
		int x;
		cin >> x;
		fer(i, 1, x)cin >> a[i];
		sort(a + 1, a + x + 1, cmp);
		fer(i, 1, x) build(a[i]);
		while(t) sol(stk[t - 1], stk[t]), t--;
		dfs3(1);
		cout << sum[1] << endl;
	    //for(int i=1;i<=id;i++)
	   // g[ss[i]].clear();
	}
}
// Entry point: sets up fast I/O and runs solve() for a single test case
// (the count is fixed to 1; it is not read from the input).
signed main() {
	IOS;
	for (int remaining = 1; remaining > 0; --remaining) {
		solve();
	}
}



版权声明:本文为qq_61305213原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。