树分治小结

分治

首先树分治是一种在树上的分治算法,分为点分治和边分治,边分治不常见并且不能保证复杂度,所以没学。那么对于点分治,核心实现是每次选取一个点当做根节点,然后解决根节点的问题,再去递归到每个子树去解决子树的问题。

分治的一个核心在于根据树的重心去划分递归子树,重心是树上的一个点,删去这个点后剩下的最大的子树最小。可以保证删去这个点后剩下的所有子树的大小不会超过原来的一半。由此保证递归深度至多 $\log n$ 层,如果解决每层的问题需要 $O(n)$,那么总时间复杂度是 $O(n\log n)$。

两道例题

POJ 1741

给出一棵树,每条边有边权,设树上任意两点 $u, v$ ($u\neq v$)的最短距离为 $dist(u,v) $ 问有多少个序偶 $(u, v)$ 使得 $dist(u, v) \le K$。

节点数不超过 $10^4$,边权不超过 $10^3$ 。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
#include <iostream>

using namespace std;

const int MAX = 112345;

int n, K;
vector<pair<int, int> > edge[MAX];

int size, root, ans;
int num[MAX], max_son[MAX], depth[MAX];
bool done[MAX];
vector<int> dep;

void get_root(int u, int fa) {
max_son[u] = 0;
num[u] = 1;
for (int i = 0; i < edge[u].size(); i++) {
int v = edge[u][i].first;
if (done[v] || v == fa) continue;
get_root(v, u);
num[u] += num[v];
max_son[u] = max(max_son[u], num[v]);
}
max_son[u] = max(max_son[u], size - num[u]);
if (max_son[u] < max_son[root]) root = u;
}

void get_dep(int u, int fa) {
dep.push_back(depth[u]);
num[u] = 1;
for (int i = 0; i < edge[u].size(); i++) {
int v = edge[u][i].first;
int w = edge[u][i].second;
if (done[v] || v == fa) continue;
depth[v] = depth[u] + w;
get_dep(v, u);
num[u] += num[v];
}
}

int calc(int u, int init) {
dep.clear();
depth[u] = init;
get_dep(u, 0);
sort(dep.begin(), dep.end());
int rst = 0;
for (int l = 0, r = dep.size() - 1; l < r;) {
if (dep[l] + dep[r] <= K) rst += r - l++;
else r--;
}
return rst;
}

void solve(int u) {
ans += calc(u, 0);
done[u] = true;
for (int i = 0; i < edge[u].size(); i++) {
int v = edge[u][i].first;
int w = edge[u][i].second;
if (done[v]) continue;
ans -= calc(v, w);
max_son[0] = size = num[v];
get_root(v, root = 0);
solve(root);
}
}

int main() {
ios::sync_with_stdio(false);
while (cin >> n >> K) {
if (n == 0 && K == 0) break;
memset(done, 0, sizeof(done));
for (int i = 1; i <= n; i++) edge[i].clear();
for (int i = 1; i < n; i++) {
int u, v, w;
cin >> u >> v >> w;
edge[u].push_back(make_pair(v, w));
edge[v].push_back(make_pair(u, w));
}
max_son[0] = size = n;
root = 0;
get_root(1, -1);
ans = 0;
solve(root);
cout << ans << endl;
}
}

CF 715C

给出一棵树,每条边有边权,定义$dist(u, v)$的值为将从u到v经过的点权依次写成10进制数,问多少个序偶 $(u, v)$ 使得 $dist(u, v) \equiv 0 \mod m$。

节点数不超过 $10^5$,边权范围为$1-9$,$m \le 10^9$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#include<bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MAX = 112345;
int MOD;

void exgcd(ll a, ll b, ll &d, ll &x, ll &y) {
if (!b) {
d = a;
x = 1;
y = 0;
}
else {
exgcd(b, a % b, d, y, x);
y -= x * (a / b);
}
}

ll inv(ll a, ll p) {
ll d, x, y;
exgcd(a, p, d, x, y);
return d == 1 ? (x + p) % p : -1;
}

ll pow_mod(ll a, ll k) {
if (MOD == 1) return 1;
ll rst = 1;
while (k) {
if (k & 1) rst = rst * a % MOD;
a = a * a % MOD;
k >>= 1;
}
return rst;
}

vector<pair<int, int> > edge[MAX];

int siz, root;
ll ans;
int num[MAX], max_son[MAX];
bool done[MAX];

void get_root(int u, int fa) {
max_son[u] = 0;
num[u] = 1;
for (const auto &e : edge[u]) {
int v = e.first;
if (done[v] || v == fa) continue;
get_root(v, u);
num[u] += num[v];
max_son[u] = max(max_son[u], num[v]);
}
max_son[u] = max(max_son[u], siz - num[u]);
if (max_son[u] < max_son[root]) root = u;
}

ll pow_num[MAX], rev_pow[MAX];
map<ll, ll> pre;

void init() {
pow_num[0] = 1;
rev_pow[0] = 1;
for (int i = 1; i < MAX; i++) pow_num[i] = pow_num[i - 1] * 10 % MOD;
for (int i = 1; i < MAX; i++) rev_pow[i] = inv(pow_num[i], MOD);
}

void add(ll t) {
if (pre.count(t)) pre[t]++;
else pre[t] = 1;
}

ll get(ll t) {
if (pre.count(t)) return pre[t];
return 0;
}

ll rst;

void dfs1(int u, int fa, int len, ll now) {
num[u] = 1;
for (const auto &e : edge[u]) {
int v = e.first;
int w = e.second;
if (done[v] || v == fa) continue;
ll t = (now + pow_num[len] * w) % MOD;
add(t);
dfs1(v, u, len + 1, t);
num[u] += num[v];
}
}

void dfs2(int u, int fa, ll now, int len) {
for (const auto &e : edge[u]) {
int v = e.first;
int w = e.second;
if (done[v] || v == fa) continue;
ll t = (now * 10 + w) % MOD;
rst += get((MOD - t) * rev_pow[len] % MOD);
dfs2(v, u, t, len + 1);
}
}

ll calc(int u, int init) {
rst = 0;
pre.clear();
dfs1(u, 0, 0, 0);
if (init) {
ll t = (init * 10 + init) % MOD;
add(0);
rst += get((MOD - t) * rev_pow[2] % MOD);
dfs2(u, 0, t, 3);
} else {
rst += get(0);
add(0);
dfs2(u, 0, 0, 1);
}
return rst;
}

ll solve(int u) {
ans += calc(u, 0);
done[u] = true;
for (const auto &e : edge[u]) {
int v = e.first;
int w = e.second;
if (done[v]) continue;
ans -= calc(v, w);
max_son[0] = siz = num[v];
get_root(v, root = 0);
solve(root);
}
return ans;
}

int main() {
ios::sync_with_stdio(false);
int n, m;
cin >> n >> m;
MOD = m;

init();

for (int i = 1; i < n; i++) {
int u, v, w;
cin >> u >> v >> w;
u++, v++;
edge[u].push_back(make_pair(v, w));
edge[v].push_back(make_pair(u, w));
}
max_son[0] = siz = n;
get_root(1, root = 0);
cout << solve(root) << endl;
}