算法笔记:dsu on tree(树上启发式合并)

ACM
4.8k words

之前写过但是vp一把遇到后发现还是没理解清楚,狠补一下这个。

参考:https://zhuanlan.zhihu.com/p/658598885

模板题:树上数颜色

给定一棵根节点为 11 的树,每个节点有不同的颜色。

多次询问,每次询问查询以 uu 节点为根节点的子树中的不同颜色树。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <bits/stdc++.h>
using namespace std;

//#undef LOCAL_DEBUG
#ifdef LOCAL_DEBUG
#include "debug.h"
#else
#define debug(...) (void)0
#define debug_array(arr, len) (void)0
#define debug_container(container) (void)0
#endif

typedef unsigned long long ull;
typedef long long ll;
typedef pair<int,int> P;
const int maxn=1e5;

int n;
vector<int> G[maxn+5];
int color[maxn+5];

int big[maxn+5];
int size[maxn+5];
void rebuild(int u,int fa){
size[u]=1;
for (int v: G[u])
if (v!=fa){
rebuild(v,u);
size[u]+=size[v];
if (size[big[u]]<size[v])
big[u]=v;
}
}

int cnt[maxn+5];
int color_size;
void add(int u,int x){
cnt[color[u]]+=x;
if (cnt[color[u]]==0)
color_size--;
else if (x==1&&cnt[color[u]]==1)
color_size++;
}
void update(int u,int fa,int x){
add(u,x);
for (int v: G[u])
if (v!=fa)
update(v,u,x);
}

int ans[maxn+5];
void dsu_on_tree(int u,int fa,bool keep){
for (int v: G[u])
if (v!=fa&&v!=big[u])
dsu_on_tree(v,u,false);
if (big[u])
dsu_on_tree(big[u],u,true);
add(u,1);
for (int v: G[u])
if (v!=fa&&v!=big[u])
update(v,u,1);
ans[u]=color_size;
if (keep==false){
add(u,-1);
for (int v: G[u])
if (v!=fa)
update(v,u,-1);
}
}

void solve(void){
cin>>n;
for (int i=1,u,v;i<n;i++){
cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
for (int i=1;i<=n;i++)
cin>>color[i];
rebuild(1,0);
dsu_on_tree(1,0,true);
int q,u;
cin>>q;
while (q--){
cin>>u;
cout<<ans[u]<<endl;
}
}

int main(){
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
int t=1;
// cin>>t;
while (t--)
solve();
// cout<<(solve()?"YES":"NO")<<endl;
return 0;
}

L. 彩色的树

也是统计子树颜色,但是这里的答案,要求是每个子树中深度不超过 kk 的。

可以把每个节点与它深度恰好为 k+1k+1 的重子树节点都连起来,然后统计前动态删掉。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include <bits/stdc++.h>
using namespace std;

#undef LOCAL_DEBUG
#ifdef LOCAL_DEBUG
#include "debug.h"
#else
#define debug(...) (void)0
#define debug_array(arr, len) (void)0
#define debug_container(container) (void)0
#endif

typedef unsigned long long ull;
typedef long long ll;
typedef pair<int,int> P;
const int maxn=1e5;
const int inf=0x3fffffff;

int n,k;
int color[maxn+5];

vector<int> G[maxn+5],T[maxn+5];

int sz[maxn+5];
int son[maxn+5];
void rebuild(int u,int fa){
sz[u]=1;
for (int v: G[u])
if (v!=fa){
rebuild(v,u);
sz[u]+=sz[v];
if (sz[son[u]]<sz[v])
son[u]=v;
}
}

int line[maxn+5],len;
void connect(int u,int fa){
line[++len]=u;
debug(u,len);
if (len-k-1>0&&son[line[len-k-1]]==line[len-k]){
debug(line[len-k-1],u);
T[line[len-k-1]].push_back(u);
}
for (int v: G[u])
if (v!=fa)
connect(v,u);
len--;
}


int cnt[maxn+5];
int color_sz;

void revise(int u,int x){
int c=color[u];
cnt[c]+=x;
if (cnt[c]==0)
color_sz--;
else if (x==1&&cnt[c]==1)
color_sz++;
}
void update(int u,int fa,int x,int depth,int lim=k){
if (depth>lim)
return ;
revise(u,x);
for (int v: G[u])
if (v!=fa)
update(v,u,x,depth+1,lim);
}

int ans[maxn+5];
void dsu_on_tree(int u,int fa,bool keep){
for (int v: G[u])
if (v!=fa&&v!=son[u])
dsu_on_tree(v,u,false);
if (son[u])
dsu_on_tree(son[u],u,true);
revise(u,1);
for (int v: G[u])
if (v!=fa&&v!=son[u])
update(v,u,1,1);
for (int v: T[u]){
debug(u,v);
revise(v,-1);
}
ans[u]=color_sz;
if (keep==false){
revise(u,-1);
update(son[u],u,-1,1);
for (int v: G[u])
if (v!=fa&&v!=son[u])
update(v,u,-1,1);
}
}

void solve(void){
cin>>n>>k;
for (int i=1;i<=n;i++)
cin>>color[i];
for (int i=1,u,v;i<n;i++){
cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
rebuild(1,0);
connect(1,0);
dsu_on_tree(1,0,true);
int q,u;
cin>>q;
while (q--){
cin>>u;
cout<<ans[u]<<endl;
}
}

int main(){
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
int t=1;
// cin>>t;
while (t--)
solve();
// cout<<(solve()?"YES":"NO")<<endl;
return 0;
}