CF600E Lomsat gelral

Dsu on tree 好题

一棵树有n个结点,每个结点都是一种颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。

首先要读懂题,题目是指那个颜色的编号和!

然后我们可以熟悉的使用dsu on tree + 线段树来做这个题qwq

具体怎么用线段树维护?如果数量相等就求个和,否则我们就选择大的继承

别的真没什么了

dsu流程:首先dfs轻儿子

其次加入重儿子的影响

然后把这个点的轻儿子都加进去

算这个点答案

如果这个点是轻链上的就不用鸟,否则删除加进去的贡献

code

#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
using std::max;
const int MAXN=5e5+7;
int n,ccnt;
int a[MAXN],home[MAXN],nxt[MAXN],to[MAXN];

#define ad2(x,y) (ct(x,y),ct(y,x))
inline void ct(int x,int y) {
	ccnt++;
	nxt[ccnt]=home[x];
	home[x]=ccnt;
	to[ccnt]=y;
}

struct rec {
	int Max;
	ll sum;
	bool operator<(const rec &x) const {
		return Max<x.Max;
	}
} tr[MAXN];
int root;
namespace seg {
#define mid ((l+r)>>1)
	static int T,ls[MAXN],rs[MAXN];
	inline void modify(int &k,int l,int r,int pos,int val) {
		if(!k)k=++T;
		if(l==r) {

			tr[k].Max+=val;
			tr[k].sum=l;

			return ;
		}

		if(pos<=mid)modify(ls[k],l,mid,pos,val);
		else modify(rs[k],mid+1,r,pos,val);

//		printf("%d %d %d %d %d %d\n",l,r,tr[ls[k]].Max,tr[ls[k]].sum,tr[rs[k]].Max,tr[rs[k]].sum);

		if(tr[ls[k]].Max<tr[rs[k]].Max) {
			tr[k].Max=tr[rs[k]].Max;
			tr[k].sum=tr[rs[k]].sum;
		} else if(tr[ls[k]].Max==tr[rs[k]].Max) {
			tr[k].Max=tr[ls[k]].Max;
			tr[k].sum=tr[ls[k]].sum+tr[rs[k]].sum;
		} else {
			tr[k].Max=tr[ls[k]].Max;
			tr[k].sum=tr[ls[k]].sum;
		}
//		printf("result:%d %d\n",tr[k].Max,tr[k].sum);
	}
}
int S;
int cnt[MAXN],son[MAXN],siz[MAXN];
ll ans[MAXN];

inline void dfs1(int u,int F) {
	siz[u]=1;
	int maxson=-1;
	for(int i=home[u]; i; i=nxt[i]) {
		int v=to[i];
		if(v==F)continue;
		dfs1(v,u);
		siz[u]+=siz[v];
		if(siz[v]>maxson) {
			maxson=siz[v];
			son[u]=v;
		}
	}
	return ;
}

inline void add(int u,int F,int t) {
	seg::modify(root,1,n,a[u],t);
	for(int i=home[u]; i; i=nxt[i]) {
		int v=to[i];
		if(v==F||v==S)continue;
		add(v,u,t);
	}
}

inline void dfs(int u,int F,int keep) {

	for(int i=home[u]; i; i=nxt[i]) {
		int v=to[i];
		if(v==F||v==son[u])continue;
		dfs(v,u,0);
	}

	if(son[u])dfs(son[u],u,1);
	S=son[u];
	add(u,F,1);
//	printf("%d\n",tr[root].);
//	puts("!!!!!");
	seg::modify(root,1,n,1,0);
	ans[u]=tr[root].sum;
//	puts("QWQWQWQWQ");
	S=0;
	if(!keep)add(u,F,-1);
	return;
}

int main() {
//	freopen("test.in","r",stdin);
	scanf("%d",&n);
	for(int i=1; i<=n; ++i)scanf("%d",&a[i]);
	for(int i=1,x,y; i<n; ++i) {
		scanf("%d%d",&x,&y);
		ad2(x,y);
	}
	dfs1(1,0);
//	for(int i=1; i<=n; ++i)printf("%dqwq\n",son[i]);
	dfs(1,0,1);
	for(int i=1; i<=n; ++i)printf("%lld ",ans[i]);
	return 0;
}