P4689 [Ynoi2016]这是我自己的发明

Ynoi2016不知哪道题

俗话说得好在你膨胀的时候做一下Ynoi,立刻就能让你感到自己的弱小

如果这种难度的数据结构也能秒掉那你可以PKUWC2020AK了

1 x 将树根换为 $x$

2 x y 给出两个点 $x,y$,从 $x$ 的子树中选一个点,再从 $y$ 的子树中选一个点,求两点点权相等的方案数。

首先我们要有一个第一选择,就是大体上有什么数据结构做这个事情

你会敏锐的发现好像两两点权相等总方案只适合带根号做法......所以我们要树上莫队

但是怎么莫队呢??我们先搞出dfs序,然后在不考虑换根的前提下,每个询问就是两个dfs序区间之间的查询

莫队的询问信息不能有四维啊....所以我们先考虑差分一下变成两维

设 $f(l_1,r_1,l_2,r_2)$ 表示从区间 $[l_1,r_1]$ 与区间 $[l_2,r_2]$ 中各选一个点且点权相等的方案数,则

$$f(l_1,r_1,l_2,r_2)=f(1,r_1,1,r_2)-f(1,l_1-1,1,r_2)-f(1,r_1,1,l_2-1)+f(1,l_1-1,1,l_2-1)$$

所以这个就变成了一个二维的问题,虽然询问数*4了

有了这个莫队部分就很显然了

再来考虑换根....哎好像是假的?

情况一 x=root,很显然此时应当查询整棵树。
qwq
情况二 lca(root,x)!=x ,此时直接查询x的子树即可,与换根无关。
qwq
情况三,lca(root,x)=x,此时设v为与x相邻的节点中离root最近的点,我们应当查询v的子树在整棵树中的补集
qwq
可以发现v一定在root到x的链上,且一定是x在这条链上的儿子,倍增法可以求得v

好的,也就是说如果有情况3我们原本两段dfs区间会变成4段,再差分变成16个询问QAQ

所以这个就是卡常毒瘤啊,40多行编译优化

code:


// Block of compiler-optimisation requests placed ahead of every include.
// NOTE(review): GCC documents this directive as `#pragma GCC optimize(...)` /
// `#pragma GCC target(...)`; the `G++` namespace used here is presumably not
// recognised, in which case every line below is skipped as an unknown pragma
// and has no effect on code generation -- confirm before relying on them.
#pragma G++ optimize(2)
#pragma G++ optimize(3)

#pragma G++ target("avx")
#pragma G++ optimize("Ofast")
#pragma G++ optimize("inline")
#pragma G++ optimize("-fgcse")
#pragma G++ optimize("-fgcse-lm")
#pragma G++ optimize("-fipa-sra")
#pragma G++ optimize("-ftree-pre")
#pragma G++ optimize("-ftree-vrp")
#pragma G++ optimize("-fpeephole2")
#pragma G++ optimize("-ffast-math")
#pragma G++ optimize("-fsched-spec")
#pragma G++ optimize("unroll-loops")
#pragma G++ optimize("-falign-jumps")
#pragma G++ optimize("-falign-loops")
#pragma G++ optimize("-falign-labels")
#pragma G++ optimize("-fdevirtualize")
#pragma G++ optimize("-fcaller-saves")
#pragma G++ optimize("-fcrossjumping")
#pragma G++ optimize("-fthread-jumps")
#pragma G++ optimize("-funroll-loops")
#pragma G++ optimize("-fwhole-program")
#pragma G++ optimize("-freorder-blocks")
#pragma G++ optimize("-fschedule-insns")
#pragma G++ optimize("inline-functions")
#pragma G++ optimize("-ftree-tail-merge")
#pragma G++ optimize("-fschedule-insns2")
#pragma G++ optimize("-fstrict-aliasing")
#pragma G++ optimize("-fstrict-overflow")
#pragma G++ optimize("-falign-functions")
#pragma G++ optimize("-fcse-skip-blocks")
#pragma G++ optimize("-fcse-follow-jumps")
#pragma G++ optimize("-fsched-interblock")
#pragma G++ optimize("-fpartial-inlining")
#pragma G++ optimize("no-stack-protector")
#pragma G++ optimize("-freorder-functions")
#pragma G++ optimize("-findirect-inlining")
#pragma G++ optimize("-fhoist-adjacent-loads")
#pragma G++ optimize("-frerun-cse-after-loop")
#pragma G++ optimize("inline-small-functions")
#pragma G++ optimize("-finline-small-functions")
#pragma G++ optimize("-ftree-switch-conversion")
#pragma G++ optimize("-foptimize-sibling-calls")
#pragma G++ optimize("-fexpensive-optimizations")
#pragma G++ optimize("-funsafe-loop-optimizations")
#pragma G++ optimize("inline-functions-called-once")
#pragma G++ optimize("-fdelete-null-pointer-checks")


#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define ll long long
using namespace std;
// Shared capacity of the statically sized arrays in this file
// (the floating literal 5e5 + 20 converts to the integer 500020).
const int MAXN = 5e5 + 20;
namespace fastIO {
#define BUF_SIZE (1<<19)
	// Input buffer. p1 is the read cursor and pend is one past the last valid
	// byte; both start at the end of the buffer so that the very first nc()
	// call triggers a refill from stdin.
	static char buf[BUF_SIZE], *p1 = buf + BUF_SIZE, *pend = buf + BUF_SIZE;

	/// Returns the next byte of stdin as a value in [0, 255], or EOF once the
	/// stream is exhausted. Without the EOF check an empty refill would leave
	/// p1 == pend == buf and the cursor would walk over stale bytes and then
	/// past the end of the buffer.
	inline int nc() {
		if(p1 == pend) {
			p1 = buf;
			pend = buf + fread(buf, 1, BUF_SIZE, stdin);
			if(p1 == pend) return EOF;
		}
		return static_cast<unsigned char>(*p1++);
	}

	/// Parses the next decimal integer from stdin.
	/// Every non-digit byte in front of the number is skipped; a '-' among
	/// those bytes makes the result negative. The first byte after the digits
	/// is consumed as well. Returns 0 if the input ends before a digit is seen.
	inline int read() {
		int x = 0, f = 1;
		int s = nc();
		// Explicit range checks instead of isdigit(): no <cctype> dependency and
		// no undefined behaviour for bytes outside the unsigned char range.
		for(; s < '0' || s > '9'; s = nc()) {
			if(s == EOF) return 0;
			if(s == '-') f = -1;
		}
		for(; s >= '0' && s <= '9'; s = nc()) x = x * 10 + (s - '0');
		return x * f;
	}
}
using namespace fastIO;
/// Prints x in decimal to stdout through putchar, without a trailing newline.
/// Negative values are printed with a leading '-'.
void write(ll x) {
	if(x < 0) {
		putchar('-');
		// Strip the last digit before negating so the most negative value
		// never has to be negated as a whole.
		if(x < -9) write(-(x / 10));
		putchar(48 - x % 10);
		return;
	}
	if(x > 9) write(x / 10);
	putchar(x % 10 + 48);
}
// Number of entries of e[] in use. ct() pre-increments it, so slot 0 is never
// filled and index 0 can serve as the end-of-list marker.
int ccnt;

// One directed adjacency-list entry. ct() is called twice per input edge, so
// the pool holds 2 * (n - 1) entries for a tree with n vertices.
struct edge {
	int to;   // vertex this entry points to
	int nxt;  // index in e[] of the next entry with the same source, 0 if none
};
edge e[MAXN];

// One prefix-pair sub-query created by RC(). After it is evaluated, the running
// pair count for prefixes 1..l and 1..r is added to ans[id] with weight f.
struct Qry {
	int l;   // length of the first prefix (0 means empty)
	int r;   // length of the second prefix (0 means empty)
	int id;  // index of the operation this sub-query belongs to
	int f;   // inclusion-exclusion sign, +1 or -1
};
Qry q[8000010];

// ---- input and bookkeeping ----
int n;               // number of vertices
int Q;               // number of operations
int m;               // number of sub-queries stored in q[1..m]
int rt;              // root in effect while the operations are read
int a[MAXN];         // a[u]: value of vertex u (its rank in o[] after compression)
int o[MAXN];         // sorted copy of the values, deduplicated for compression
int is_q[MAXN];      // is_q[i] == 1 iff operation i produces an output line

// ---- adjacency lists ----
int home[MAXN];      // home[u]: index in e[] of the first entry of u, 0 if none

// ---- data filled by dfs1 / dfs2 ----
int _dfn;            // preorder counter
int in[MAXN];        // in[u]: preorder index of u
int out[MAXN];       // out[u]: largest preorder index handed out inside u's subtree
int c[MAXN];         // c[k]: value of the vertex whose preorder index is k
int dep[MAXN];       // dep[u]: depth of u
int fa[MAXN];        // fa[u]: parent of u (the start vertex is its own parent)
int siz[MAXN];       // siz[u]: subtree size of u
int son[MAXN];       // son[u]: child of u with the largest subtree, 0 for a leaf
int top[MAXN];       // top[u]: topmost vertex of the heavy chain containing u

// ---- Euler tour + sparse table used by LCA ----
int idx;             // Euler tour length so far
int dfn[MAXN];       // dfn[u]: first position of u in the Euler tour
int st[MAXN][21];    // st[i][j]: shallowest vertex among tour positions i .. i + 2^j - 1
int lg2[MAXN];       // lg2[k]: floor(log2(k))

// ---- scratch used while turning one query into sub-queries ----
int tl[MAXN];        // left ends of the preorder segments produced by devide()
int tr[MAXN];        // right ends of the preorder segments produced by devide()

// ---- Mo's algorithm state ----
int cnt[MAXN][2];    // cnt[v][p]: occurrences of value v inside prefix p
ll ans[MAXN];        // ans[i]: accumulated answer of operation i
ll ima;              // pair count for the two prefixes currently maintained
ll _c;               // number of segments currently stored in tl[] / tr[]
ll SIZ;              // block size used by cmp()

void ct(int u, int v) {
	e[++ccnt] = (edge) {
		v, home[u]
	};
	home[u] = ccnt;
}

/// Ordering for Mo's algorithm: sub-queries whose l falls into the same block
/// of width SIZ are ordered by r, all others by l.
bool cmp(Qry a, Qry b) {
	const bool same_block = (a.l / SIZ == b.l / SIZ);
	if(same_block) return a.r < b.r;
	return a.l < b.l;
}

/// First traversal from u, whose parent is F. Fills the preorder interval
/// in[u]..out[u], the value-by-preorder array c[], depth, parent, subtree size
/// and heavy child, and appends u to the Euler tour (level 0 of st) on entry
/// and again after each child returns.
void dfs1(int u, int F) {
	++idx;
	st[idx][0] = u;
	dfn[u] = idx;

	++_dfn;
	in[u] = _dfn;
	c[_dfn] = a[u];

	fa[u] = F;
	dep[u] = dep[F] + 1;
	siz[u] = 1;

	for(int eid = home[u]; eid != 0; eid = e[eid].nxt) {
		const int child = e[eid].to;
		if(child != F) {
			dfs1(child, u);
			st[++idx][0] = u;   // u is visited again after leaving the child
			siz[u] += siz[child];
			if(siz[son[u]] < siz[child]) son[u] = child;
		}
	}
	out[u] = _dfn;
}

/// Second traversal: assigns top[u] = topf, extends the same chain through the
/// heavy child and starts a new chain at every other child.
void dfs2(int u, int topf) {
	top[u] = topf;
	const int heavy = son[u];
	if(heavy != 0) dfs2(heavy, topf);
	for(int eid = home[u]; eid != 0; eid = e[eid].nxt) {
		const int v = e[eid].to;
		if(v == heavy || v == fa[u]) continue;
		dfs2(v, v);   // a light child is the top of its own chain
	}
}
/// Finds which child subtree of y contains x.
/// Intended for x strictly inside the subtree of y (this is how devide() calls
/// it). x climbs chain by chain until it reaches y's chain: if it lands exactly
/// on y, the top of the last chain left is the wanted child; otherwise x sits
/// below y on y's own chain and the child is y's heavy child.
/// Returns 0 when x == y on entry (no such child exists).
int find(int y, int x) {
	int u = 0;   // initialised so the x == y case no longer reads an indeterminate value
	while(top[x] != top[y]) {
		u = top[x];
		x = fa[u];
	}
	return x == y ? u : son[y];
}

/// Runs both traversals from vertex 1 and builds the lg2[] table and the
/// sparse table over the Euler tour that LCA() queries.
void init() {
	// dfs1 sets dep[u] = dep[F] + 1 and is started with F == u == 1, so the
	// start vertex ends up with depth 2; only relative depths are compared.
	dep[1] = 1;
	dfs1(1, 1);
	dfs2(1, 1);

	const int len = n << 1;
	for(int i = 2; i <= len; ++i)lg2[i] = lg2[i >> 1] + 1;

	// st has 21 columns, so the valid levels are 0..20. The previous bound of
	// j <= 21 relied on the inner loop being empty to avoid touching st[i][21].
	for(int j = 1; j < 21; ++j) {
		const int half = 1 << (j - 1);
		for(int i = 1; i + (1 << j) - 1 <= len; ++i) {
			const int lo = st[i][j - 1];
			const int hi = st[i + half][j - 1];
			st[i][j] = dep[lo] < dep[hi] ? lo : hi;
		}
	}
}

/// Lowest common ancestor of u and v in the tree rooted at vertex 1: the
/// shallowest vertex on the Euler tour between their first occurrences,
/// obtained from two overlapping sparse-table windows.
int LCA(int u, int v) {
	int lo = dfn[u];
	int hi = dfn[v];
	if(lo > hi) swap(lo, hi);
	const int k = lg2[hi - lo + 1];
	const int left = st[lo][k];
	const int right = st[hi - (1 << k) + 1][k];
	return dep[left] < dep[right] ? left : right;
}

void RC(int l1, int r1, int l2, int r2, int id) {
	q[++m] = (Qry) {
		r1, r2, id, 1
	};
	q[++m] = (Qry) {
		l1 - 1, l2 - 1, id, 1
	};
	q[++m] = (Qry) {
		r1, l2 - 1, id, -1
	};
	q[++m] = (Qry) {
		l1 - 1, r2, id, -1
	};
	//二维前缀和/jk
}

/// Appends to tl[] / tr[] the preorder segments that make up the subtree of x
/// when the tree is rooted at rt (preorder indices refer to the rooting at 1).
///  - x == rt: the whole tree, [1, n].
///  - x is not an ancestor of rt: the subtree is unchanged, [in[x], out[x]].
///  - x is a proper ancestor of rt: everything except the subtree of the child
///    y of x that contains rt, i.e. [1, in[y] - 1] and [out[y] + 1, n].
void devide(int x) {
	if(x == rt) {
		tl[++_c] = 1;
		tr[_c] = n;
		return;
	}
	if(LCA(x, rt) != x) {
		tl[++_c] = in[x];
		tr[_c] = out[x];
		return;
	}
	const int y = find(x, rt);
	// Both writes now sit under the guard. Previously only tl[] was guarded and
	// tr[_c] was assigned unconditionally, which would have overwritten the
	// right end of the previously stored segment had the guard ever failed.
	if(in[y] > 1) {
		tl[++_c] = 1;
		tr[_c] = in[y] - 1;
	}
	if(out[y] < n) {
		tl[++_c] = out[y] + 1;
		tr[_c] = n;
	}
}

void build(int x, int y, int id) {
	_c = 0;
	//	printf("%d %d %d?\n", x, y, id);
	devide(x);
	int mid = _c;
	devide(y);
	//下面是压行写法233
	for(int i = 1; i <= mid; ++i) {
		for(int j = mid + 1; j <= _c; ++j) {
			//	printf("%d %d %d %d &\n", tl[i], tr[i], tl[j], tr[j]);
			RC(tl[i], tr[i], tl[j], tr[j], id);
			//把这些信息计入Q
		}
	}
}

/// Extends prefix p (0 or 1) by preorder position x: the new element pairs up
/// with every equal value already counted in the other prefix.
void add(int x, int p) {
	int *slot = cnt[c[x]];
	ima += slot[p ^ 1];
	++slot[p];
}
/// Shrinks prefix p (0 or 1) by preorder position x: removes the pairs the
/// element formed with equal values counted in the other prefix.
void del(int x, int p) {
	int *slot = cnt[c[x]];
	ima -= slot[p ^ 1];
	--slot[p];
}


int main() {
	//freopen("test.in", "r", stdin);
	n = read();
	Q = read();
	SIZ = sqrt(n);
	for(register int i = 1; i <= n; ++i)a[i] = read(), o[i] = a[i];
	sort(o + 1, o + n + 1);
	int _n = unique(o + 1, o + n + 1) - o - 1;
	for(register int i = 1; i <= n; ++i)a[i] = lower_bound(o + 1, o + _n + 1, a[i]) - o;
	for(register int u, v, i = 1; i < n; ++i)u = read(), v = read(), ct(u, v), ct(v, u);
	init();
	rt = 1;
	for(register int opt, x, y, i = 1; i <= Q; ++i) {
		opt = read();
		x = read();
		if(opt == 1)rt = x;
		else is_q[i] = 1, y = read(), build(x, y, i);
	}
	for(int i = 1; i <= m; ++i)if(q[i].l > q[i].r)swap(q[i].l, q[i].r);
	sort(q + 1, q + m + 1, cmp);
	for(register int L = 0, R = 0, i = 1; i <= m; ++i) {
		int l = q[i].l, r = q[i].r;
		//	printf("%d %d %d %d\n", q[i].l, q[i].r, q[i].id, q[i].f);
		while(L < l)add(++L, 0);
		while(L > l)del(L--, 0);
		while(R < r)add(++R, 1);
		while(R > r)del(R--, 1);
		ans[q[i].id] += ima * q[i].f;
	}
	for(int i = 1; i <= Q; ++i)if(is_q[i])write(ans[i]), putchar('\n');
	return 0;
}