CF600E Lomsat gelral
Dsu on tree 好题
一棵树有n个结点,每个结点都是一种颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。
首先要读懂题,题目是指那个颜色的编号和!
然后我们可以熟悉的使用dsu on tree + 线段树来做这个题qwq
具体怎么用线段树维护?如果数量相等就求个和,否则我们就选择大的继承
别的真没什么了
dsu流程:首先dfs轻儿子
其次加入重儿子的影响
然后把这个点的轻儿子都加进去
算这个点答案
如果这个点是轻链上的就不用鸟,否则删除加进去的贡献
code
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
using std::max;
const int MAXN=5e5+7;
int n,ccnt;
int a[MAXN],home[MAXN],nxt[MAXN],to[MAXN];
#define ad2(x,y) (ct(x,y),ct(y,x))
inline void ct(int x,int y) {
ccnt++;
nxt[ccnt]=home[x];
home[x]=ccnt;
to[ccnt]=y;
}
struct rec {
int Max;
ll sum;
bool operator<(const rec &x) const {
return Max<x.Max;
}
} tr[MAXN];
int root;
namespace seg {
#define mid ((l+r)>>1)
static int T,ls[MAXN],rs[MAXN];
inline void modify(int &k,int l,int r,int pos,int val) {
if(!k)k=++T;
if(l==r) {
tr[k].Max+=val;
tr[k].sum=l;
return ;
}
if(pos<=mid)modify(ls[k],l,mid,pos,val);
else modify(rs[k],mid+1,r,pos,val);
// printf("%d %d %d %d %d %d\n",l,r,tr[ls[k]].Max,tr[ls[k]].sum,tr[rs[k]].Max,tr[rs[k]].sum);
if(tr[ls[k]].Max<tr[rs[k]].Max) {
tr[k].Max=tr[rs[k]].Max;
tr[k].sum=tr[rs[k]].sum;
} else if(tr[ls[k]].Max==tr[rs[k]].Max) {
tr[k].Max=tr[ls[k]].Max;
tr[k].sum=tr[ls[k]].sum+tr[rs[k]].sum;
} else {
tr[k].Max=tr[ls[k]].Max;
tr[k].sum=tr[ls[k]].sum;
}
// printf("result:%d %d\n",tr[k].Max,tr[k].sum);
}
}
int S;
int cnt[MAXN],son[MAXN],siz[MAXN];
ll ans[MAXN];
inline void dfs1(int u,int F) {
siz[u]=1;
int maxson=-1;
for(int i=home[u]; i; i=nxt[i]) {
int v=to[i];
if(v==F)continue;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>maxson) {
maxson=siz[v];
son[u]=v;
}
}
return ;
}
inline void add(int u,int F,int t) {
seg::modify(root,1,n,a[u],t);
for(int i=home[u]; i; i=nxt[i]) {
int v=to[i];
if(v==F||v==S)continue;
add(v,u,t);
}
}
inline void dfs(int u,int F,int keep) {
for(int i=home[u]; i; i=nxt[i]) {
int v=to[i];
if(v==F||v==son[u])continue;
dfs(v,u,0);
}
if(son[u])dfs(son[u],u,1);
S=son[u];
add(u,F,1);
// printf("%d\n",tr[root].);
// puts("!!!!!");
seg::modify(root,1,n,1,0);
ans[u]=tr[root].sum;
// puts("QWQWQWQWQ");
S=0;
if(!keep)add(u,F,-1);
return;
}
int main() {
// freopen("test.in","r",stdin);
scanf("%d",&n);
for(int i=1; i<=n; ++i)scanf("%d",&a[i]);
for(int i=1,x,y; i<n; ++i) {
scanf("%d%d",&x,&y);
ad2(x,y);
}
dfs1(1,0);
// for(int i=1; i<=n; ++i)printf("%dqwq\n",son[i]);
dfs(1,0,1);
for(int i=1; i<=n; ++i)printf("%lld ",ans[i]);
return 0;
}