虚树小结

虚树

据说虚树的英文是叫 Auxiliary tree,简单的来说是维护树上仅于题目有关的信息,除去掉冗余信息以保证时间复杂度的一种算法。为什么说是虚树呢,是因为他是在原来树的基础上建出的一颗全新的树,需要重新建边或分配点权等信息。

虚树的题目通常是这样的,给出一颗大小为 n 有边权或点权的树,然后有 q 次询问,每次给出 $k_i$ 个关键点,要求计算关于这 k 个点的问题(通常和 LCA 有关),但是 $\sum k_i$ 又比较小(通常和树的大小同阶)。

假设你可以 $O(n)$ 的在树上解决单次的询问,那么总时间复杂度是 $O(nq)$ 的,问题的 n 和 q 一般为 $10^5$ 显然这样是不可以过这道题的。这时候就需要用到虚树了,使用虚树我们对于每个询问都可以 $O(k_i)$ 的建出一颗有 $O(k_i)$ 个节点的树,并且可以通过这棵树 $O(k_i)$ 的计算出答案。

BZOJ 2886

给出一颗 n 个节点的树,每条边有正边权,有 q 次询问,每个询问包含 $k_i$ 个选中的点,要求选出一些边使得去掉这 k 条边后 1 号节点不与这 k 个点中的任一个连通。

$ 1 \le n \le 250000, \sum k_i \le 5 \times10^5$

分析

对于每次询问,我们首先考虑 $O(n)$ 的做法,我设 $dp[u]$ 是使得 u 节点及其子树中所有选中的点与 u 的父亲断开的最小花费,我们考虑接下来两种转移;

如果 u 号节点是选中的点的话,那么肯定要断开 u 号点到父亲的边,并且为了让花费最少 u 的子树中不需要断开其他边了。

如果 u 号节点不是选中的点,那么如果我们不断开从 u 到其父亲的边的话,那么我们需要保证 u 的子树中所有的选中的点不与 u 连通,那么 $dp[u] = \sum dp[v]$,v 是 u 的所有的儿子。

虚树优化

接下来我们考虑这个 DP,对于树上一条没有分叉的链,我们如果要断开这条链的话我们肯定要断开这条链上的边权最小的边。仔细分析的话会发现,对于这个问题的每次询问,其实只和选中的 $k_i$ 个点及他们之间两两的 LCA 有关。

思考到这一步,并且会虚树的话,这个题已经解决了。虚树可以对某些关键点求出其两两的 LCA,并且有证明给出,k 个点的两两求 LCA,得出的不同的点的数量是 $O(k)$ 的。

虚树构造

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
//储存关键的点及其 dfs 序
vector<pair<int, int> > imp;
//虚树上的边
vector<int> G[MAX];
//构建虚树时用到的栈
int s[MAX], s_top;

void build() {
//按照 dfs 序排序
sort(imp.begin(), imp.end());
//为了方便处理,我们把树根固定当做关键点
s[s_top = 1] = 1;

//按照 dfs 序遍历每个点
for (int i = 0; i < imp.size(); i++) {
//求出当前节点与栈顶节点的 LCA
int v = imp[i].second;
int t = lca(v, s[s_top]);
// t 的 dfs 序肯定小于等于栈顶元素的 (因为他是 v 和栈顶元素的 LCA)
while (dfn[t] < dfn[s[s_top]]) {
//直到栈顶的元素的前一个元素的 dfs 序小于当前的 LCA,停止循环,如果当前的 LCA 不在栈中,那么入栈
if (dfn[t] >= dfn[s[s_top - 1]]) {
G[t].push_back(s[s_top]);
if (s[--s_top] != t) s[++s_top] = t;
break;
}
//进入循环以为着栈顶元素 dfs 小于当前的 LCA,并且我们是按照 dfs 序遍历的,那么栈顶元素的子树肯定已经构建完毕,需要出栈并对当前的 LCA 建边
G[s[s_top - 1]].push_back(s[s_top]);
s_top--;
}
s[++s_top] = v;
}
//最后对栈中的元素依次建边
while (s_top--) G[s[s_top]].push_back(s[s_top + 1]);
}

题目代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int MAX = 250000 + 100;

struct Edge {
int u, v, w, nxt;
} edge[MAX * 2];

int head[MAX], etot;
int n;

void add_edge(int u, int v, int w) {
edge[etot].u = u;
edge[etot].v = v;
edge[etot].w = w;
edge[etot].nxt = head[u];
head[u] = etot++;
}

vector<pair<int, int> > imp;

int dfn[MAX], clo;
int f[MAX][20], dep[MAX];
ll pre[MAX][20];

void dfs(int u, int fa, int depth) {
dfn[u] = ++clo;
f[u][0] = fa;
dep[u] = depth;
for (int i = head[u]; ~i; i = edge[i].nxt) {
int v = edge[i].v;
if (dfn[v]) continue;
pre[v][0] = edge[i].w;
dfs(v, u, depth + 1);
}
}

int lca(int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
int diff = dep[u] - dep[v];
for (int i = 0; i < 20; i++) if (diff >> i & 1) u = f[u][i];
if (u == v) return u;
for (int i = 19; i >= 0; i--) {
if (f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
}
return f[u][0];
}

ll get_pre(int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
int diff = dep[u] - dep[v];
ll ans = 1e18;
for (int i = 0; i < 20; i++)
if (diff >> i & 1) {
ans = min(ans, pre[u][i]);
u = f[u][i];
}
return ans;
}

void init() {
for (int i = 1; i < 20; i++) {
for (int j = 1; j <= n; j++) {
f[j][i] = f[f[j][i - 1]][i - 1];
pre[j][i] = min(pre[f[j][i - 1]][i - 1], pre[j][i - 1]);
}
}
}

bool ojbk[MAX];

vector<int> G[MAX];
vector<int> occur;

int s[MAX], s_top;

void build() {
sort(imp.begin(), imp.end());

s[s_top = 1] = 1;
occur.push_back(1);

for (int i = 0; i < imp.size(); i++) {
int v = imp[i].second;
int t = lca(v, s[s_top]);
occur.push_back(v);
occur.push_back(t);

while (dfn[t] < dfn[s[s_top]]) {
if (dfn[t] >= dfn[s[s_top - 1]]) {
G[t].push_back(s[s_top]);
if (s[--s_top] != t) s[++s_top] = t;
break;
}
G[s[s_top - 1]].push_back(s[s_top]);
s_top--;
}
s[++s_top] = v;
}
while (s_top--) G[s[s_top]].push_back(s[s_top + 1]);
}

ll dp[MAX];

void dfs(int u, ll pre) {
dp[u] = pre;
ll sum = 0;
for (int i = 0; i < G[u].size(); i++) {
int v = G[u][i];
dfs(v, get_pre(u, v));
sum += dp[v];
}
if (!ojbk[u]) dp[u] = min(dp[u], sum);
}

void magic() {
build();
dfs(1, 1e18);
for (int i = 0; i < occur.size(); i++) {
G[occur[i]].clear();
ojbk[occur[i]] = false;
}
occur.clear();
}


int main() {
memset(pre, 0x3f, sizeof(pre));
memset(head, -1, sizeof(head));
scanf("%d", &n);
for (int i = 1; i < n; i++) {
int u, v, w;
scanf("%d %d %d", &u, &v, &w);
add_edge(u, v, w);
add_edge(v, u, w);
}
dfs(1, 0, 0);
init();
int m;
scanf("%d", &m);
while (m--) {
imp.clear();
int k;
scanf("%d", &k);
for (int i = 1; i <= k; i++) {
int t;
scanf("%d", &t);
ojbk[t] = true;
imp.push_back(make_pair(dfn[t], t));
}
magic();
printf("%lld\n", dp[1]);
}
}