P6329 【模板】点分树 | 震波

看到模板两字,先建出点分树吧

就是在点分治的过程把每个点记录下来,然后另开一张图记录下来

然后再在新图上dfs一遍,我们把每个点的子树全部压入一个线段树

由于我们一些奇怪的原因只能用动态开点线段树 保证空间合法,下标是子树内所有点到他的距离

然后查询的时候暴力向父亲跳,然后每次log查询一下,注意容斥容斥,就是减去跳上去的儿子距离他小的部分

但是这个容斥并不简单,所以要对于每个点再开一个线段树记录子树内所有点到(点分树父亲的距离)为下标的权值

算法流程:

  1. 建出点分树

  2. 在点分树上用O(sizlogsiz)O(\sum sizlogsiz)的复杂度建出两棵线段树,一个记录到父亲的一个记录到自己的

  3. 查询和修改的时候暴力跳父亲

extra细节:

建点分树的时候第一步要把全局中心找出来当做根

查询时如果跳到某个点距离大于了阈值不能停下,继续向上跳,因为可能某个祖先的距离到他又合法了...

修改的时候要改两棵树

code:


#include<bits/stdc++.h>
using namespace std;
const int MAXN = 3e5 + 7;
const int MAXM = 7e5 + 7;
const int MAXT = 2e7 + 7;
int n, m, a[MAXN];
int ccnt, home[MAXN], nxt[MAXM], to[MAXM];

inline void ct(int x, int y) {
	ccnt++;
	nxt[ccnt] = home[x];
	home[x] = ccnt;
	to[ccnt] = y;
}

int first[MAXN], cccnt, dep[MAXN];
struct rec {
	int to, nxt;
} e[MAXM];
inline void ct2(int x, int y) {
	cccnt++;
	e[cccnt].nxt = first[x];
	first[x] = cccnt;
	e[cccnt].to = y;
}

int rt, S, RT, rt2;
int siz[MAXN], dp[MAXN], vis[MAXN];
inline void getroot(int u, int F) {
	siz[u] = 1;
	dp[u] = 0;
	for(int i = home[u]; i; i = nxt[i]) {
		int v = to[i];
		if(v == F || vis[v])continue;
		getroot(v, u);
		siz[u] += siz[v];
		dp[u] = max(dp[u], siz[v]);
	}
	dp[u] = max(dp[u], S - siz[u]);
	if(dp[rt] > dp[u])
		rt = u;
	return;
}

inline void solve(int u) {
	vis[u] = 1;
	for(int i = home[u]; i; i = nxt[i]) {
		int v = to[i];
		if(vis[v])continue;
		rt = 0;
		dp[rt] = n;
		S = siz[v];
		getroot(v, u);
		ct2(u, rt);
		ct2(rt, u);
		// printf("%d %d??\n", u, rt);
		solve(rt);
	}
	return ;
}

int fa[MAXN], son[MAXN], top[MAXN];
inline void dfs1(int u, int F) {
	siz[u] = 1;
	fa[u] = F;
	for(int i = home[u]; i; i = nxt[i]) {
		int v = to[i];
		if(v == F)continue;
		dep[v] = dep[u] + 1;
		dfs1(v, u);
		siz[u] += siz[v];
		if(siz[v] > siz[son[u]])son[u] = v;
	}
	return ;
}

inline void dfs2(int u, int topf) {
	top[u] = topf;
	if(!son[u])return ;
	dfs2(son[u], topf);
	for(int i = home[u]; i; i = nxt[i]) {
		int v = to[i];
		if(v == fa[u] || v == son[u])continue;
		dfs2(v, v);
	}
	return ;
}

inline int LCA(int x, int y) {
	while(top[x] != top[y]) {
		// printf("%d?%d\n", x, y);
		if(dep[top[x]] < dep[top[y]])x ^= y ^= x ^= y;
		x = fa[top[x]];
	}
	if(dep[x] > dep[y])x ^= y ^= x ^= y;
	return x;
}

inline void init() {
	dep[1] = 1;
	dfs1(1, 0);//预处理LCA相关
	dfs2(1, 1);
	S = n;
	dp[rt = 0] = n;
	getroot(1, 0);
	RT = rt;
	solve(rt);//建出点分树
	return ;
}

int root[MAXN];
namespace seg {
	int ls[MAXT], rs[MAXT], T, sum[MAXT];
#define mid ((l+r)>>1)
	inline void modify(int &k, int l, int r, int pos, int V) {
		if(!k)k = ++T;
		if(l == r) {
			sum[k] += V;
			return ;
		}
		if(pos <= mid)modify(ls[k], l, mid, pos, V);
		else modify(rs[k], mid + 1, r, pos, V);
		sum[k] = sum[ls[k]] + sum[rs[k]];
	}
	inline int query(int k, int l, int r, int x, int y) {
		if(x <= l && y >= r) {
			return sum[k];
		}
		if(y <= mid)return query(ls[k], l, mid, x, y);
		else if(x > mid)return query(rs[k], mid + 1, r, x, y);
		else return query(ls[k], l, mid, x, y) + query(rs[k], mid + 1, r, x, y);
	}
}
using namespace seg;
int  fa2[MAXN];

inline int DIS(int x, int y) {
	return dep[x] + dep[y] - 2 * dep[LCA(x, y)];
}

inline void add(int u, int F) {
	// printf("in- > %d %d %d %d %d\n", u, rt, S, DIS(u, rt), a[u]);
	// printf("try!%d %d %d %d?\n", u, F, DIS(u, rt), a[u]);
	modify(root[rt], 0, S, DIS(u, rt2), a[u]);
	// printf("query ->:%d?\n", query(root[rt], 0, S, 0, S));
	for(int i = first[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if(v == F)continue;
		add(v, u);
	}
	return;
}

inline void dfs3(int u, int F) {
	fa2[u] = F;
	rt = u;
	S = siz[u];
	rt2 = u;
	// printf("in->%d?%d %d\n", u, F, root[u]);
	add(u, F);
	S = siz[F];
	rt = u + n;
	rt2 = F;
	add(u, F);//到父亲的答案...
	// printf("%d %d?\n", u, F);
	for(int i = first[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if(v == F)continue;
		dfs3(v, u);
	}
	return ;
}

int main() {
	scanf("%d%d", &n, &m);
	for(int i = 1; i <= n; ++i)scanf("%d", a + i);
	for(int i = 2, u, v; i <= n; ++i) {
		scanf("%d%d", &u, &v);
		ct(u, v);
		ct(v, u);
	}
	init();
	dfs3(RT, 0);//dfs new tree
	int lsa = 0;
	for(int i = 1, x, y, z; i <= m; ++i) {
		scanf("%d%d%d", &z, &x, &y);
		x ^= lsa;
		y ^= lsa;
		if(z == 1) {
			int anc = x;
			while(anc) {
				// printf("%d??\n", DIS(x, anc));
				modify(root[anc], 0, siz[anc], DIS(x, anc), y - a[x]);
				if(fa2[anc])
					modify(root[anc + n], 0, siz[fa2[anc]], DIS(x, fa2[anc]), y - a[x]);
				anc = fa2[anc];
			}
			a[x] = y;
		} else {
			int anc = x, tmp = 0, lst = 0;
			lsa = 0;
			while(anc) {
				tmp = y - DIS(x, anc);
				// printf("now step : anc%d  lstanc:%d  jl:%d\n", anc, lsa, tmp);
				if(tmp < 0) {
					lst = anc;
					anc = fa2[anc];
					continue;
				}
				lsa += query(root[anc], 0, siz[anc], 0, tmp);
				// printf("ans is ?:%d ", query(root[anc], 0, siz[anc], 0, tmp));

				// printf("- %d??\n", query(root[anc + n], 0, siz[fa2[anc]], 0, tmp));
				if(lst)
					lsa -= query(root[lst + n], 0, siz[anc], 0, tmp);
				lst = anc;
				anc = fa2[anc];
			}
			printf("%d\n", lsa);
		}
	}
	return 0;
}