CF504E Misha and LCP on Tree

IOI2020集训队作业

  • 给定一棵 nn个节点的树,每个节点有一个小写字母。
  • mm 组询问,每组询问为树上 aba \to bcdc \to d组成的字符串的最长公共前缀。
  • n3×105n \le 3 \times 10^5m106m \le 10^6

树上字符串QAQ先考虑序列上怎么做吧

相当于区间最长公共前缀,可以O(logn)O(logn)二分哈希值来做

放到树上好像就是怎么快速得到一段路径的哈希值

这是一个trick呢,维护每个节点到根的正串(root->x)hash值和反串(x->root)hash值

那么从x->y路径hash值就是先找到LCA,然后dep[x]-dep[LCA]这么长的反串hash值截取出来,然后把dep[y]-dep[LCA]这么长的正串hash值截取出来,二者拼接一下就行了

不过截取是实现不了的,所以我们考虑用二分时的树上k级祖先来代替

比如二分到前L位置,那么先判断L在x->LCA还是LCA->y上,然后一个k级祖先把它找到拿k级祖先处的字符串来比较就可以了

看代码吧qwq

code:

#include<bits/stdc++.h>
#define lg2(x)(31-__builtin_clz(x))
using namespace std;
#define ll long long

const int MAXN=3e5+7;
const int md1=1004535809,md2=167772161;
int base1,base2;
int _1[MAXN],_2[MAXN],n,m;
int i1[MAXN],i2[MAXN];
char s[MAXN];
int home[MAXN],ccnt;
int mxd[MAXN],toph[MAXN],tail[MAXN],sonh[MAXN],fa[MAXN],len[MAXN],F[20][MAXN],dep[MAXN],idx,dfn[MAXN];
int *up[MAXN],*down[MAXN];
int st[20][MAXN<<1];

inline int inv(int a,const int &P) {
	ll ans=1,b=P-2;
	while(b) {
		if(b&1)ans=ans*a%P;
		a=1ll*a*a%P;
		b>>=1;
	}
	return ans;
}
struct edge {
	int to,nxt;
} e[MAXN<<1];

struct data {
	int s1,s2;
	int len;
	inline void push_front(char c) {
		s1=(s1+1ll*c*_1[len])%md1,s2=(s2+1ll*c*_2[len++])%md2;
	}
	inline void push_back(char c) {
		s1=(1ll*s1*base1+c)%md1,s2=(1ll*s2*base2+c)%md2,++len;
	}
	inline data operator+(const data &rhs)const {
		return (data) {
			(1ll*s1*_1[rhs.len]+rhs.s1)%md1,(1ll*s2*_2[rhs.len]+rhs.s2)%md2,len+rhs.len
		};
	}
	inline bool operator==(const data &rhs)const {
		return s1==rhs.s1&&s2==rhs.s2&&len==rhs.len;
	}
} a[MAXN],b[MAXN];

inline void dfs(int nw) {
	st[0][dfn[nw]=++idx]=nw;
	sonh[nw]=0,mxd[nw]=dep[nw];//mxd就是maxdep
	a[nw]=a[fa[nw]],b[nw]=b[fa[nw]];
	a[nw].push_back(s[nw]),b[nw].push_front(s[nw]);
	for(int i=home[nw]; i; i=e[i].nxt) {
		int v=e[i].to;
		if(!dep[v]) {
			dep[v]=dep[nw]+1;
			fa[v]=F[0][v]=nw;
			dfs(v);
			st[0][++idx]=nw;
			if(mxd[v]>mxd[nw])mxd[nw]=mxd[v],sonh[nw]=v;
		}
	}
}
inline void dfs2(int nw) {
	tail[toph[nw]]=nw;
	len[toph[nw]]=dep[nw]-dep[toph[nw]]+1;
	if(sonh[nw])toph[sonh[nw]]=toph[nw],dfs2(sonh[nw]);
	for(int i=home[nw]; i; i=e[i].nxt) {
		int v=e[i].to;
		if(dep[v]>dep[nw]&&v!=sonh[nw])dfs2(toph[v]=v);
		//标准长链剖分
	}
}

inline int kfa(int x,int k) {
	if(dep[x]<=k)return 0;
	if(k==0)return x;
	const int lg=lg2(k);
	x=F[lg][x],k-=1<<lg;
	if(!k)return x;
	const int dlt=dep[x]-dep[toph[x]];
	if(dlt>=k)return down[toph[x]][dlt-k];
	return up[toph[x]][k-dlt];
}

inline int LCA(int x,int y) {
	if(dfn[x]>dfn[y])swap(x,y);
	x=dfn[x],y=dfn[y];
	int lg=lg2(y-x+1);
	int a=st[lg][x],b=st[lg][y-(1<<lg)+1];
	return (dep[a]<dep[b])?a:b;
}

inline void init() {
	for(int i=1; i<19; ++i) {
		for(int j=1; j<=n; ++j) {
			F[i][j]=F[i-1][F[i-1][j]];//倍增处理
		}
	}
	for(int i=1; i<=n; ++i) {
		if(toph[i]==i) {
			int h=len[i];
			up[i]=new int[h+2];
			down[i]=new int [h+2];//动态开空间处理
			*up[i]=*down[i]=i;
			for(int j=1,nw=fa[i]; j<=h&&nw; ++j,nw=fa[nw]) {
				up[i][j]=nw;
			}
			for(int j=1,nw=sonh[i]; j<=h&&nw; ++j,nw=sonh[nw]) {
				down[i][j]=nw;
			}
		}
	}
	for(int i=1; i<20; ++i) {
		for(int j=1; j<=2*n; ++j) {
			int x=st[i-1][j],y=st[i-1][j+(1<<i-1)];
			st[i][j]=(dep[x]<dep[y])?x:y;//st±í¿ì°¡
		}
	}
}

inline data get(int u,int v) {
	const data &x=b[u],&y=b[v];
	return (data) {
		1ll*(x.s1-y.s1+md1)*i1[y.len]%md1,1ll*(x.s2-y.s2+md2)*i2[y.len]%md2,x.len-y.len//暴力截取
	};
}

inline data get_(int u,int v) {
	const data&x=a[u],&y=a[v];
	return (data) {
		(y.s1-1ll*x.s1*_1[y.len-x.len]%md1+md1)%md1,(y.s2-1ll*x.s2*_2[y.len-x.len]%md2+md2)%md2,y.len-x.len//暴力截取2
	};
}

int main() {
	srand(time(0));
	base1=rand()%100+200,base2=rand()%300+400;
	scanf("%d%s",&n,s+1);
	_1[0]=_2[0]=i1[0]=i2[0]=1;
	for(int i=1; i<=n; ++i)_1[i]=1ll*_1[i-1]*base1%md1,_2[i]=1ll*_2[i-1]*base2%md2;
	i1[1]=inv(base1,md1);
	i2[1]=inv(base2,md2);
	for(int i=2; i<=n; ++i) {
		i1[i]=1ll*i1[i-1]*i1[1]%md1,i2[i]=1ll*i2[i-1]*i2[1]%md2;
	}
	for(int i=1; i<n; ++i) {
		int u,v;
		scanf("%d%d",&u,&v);
		e[++ccnt]=(edge) {v,home[u]},home[u]=ccnt;
		e[++ccnt]=(edge) {u,home[v]},home[v]=ccnt;
	}
	dfs(dep[1]=1),dfs2(toph[1]=1);
	init();
	//for(int i=1;i<=n;++i)printf("%d %d %d\n",dep[i],toph[i],fa[i]);
	//puts("qwq");
	for(scanf("%d",&m); m--;) {
		int u1,v1,u2,v2;
		scanf("%d%d%d%d",&u1,&v1,&u2,&v2);
		int L1=LCA(u1,v1),L2=LCA(u2,v2);
	//	printf("%d %d\n",L1,L2);
		const int D1=dep[u1]+dep[v1]-2*dep[L1]+1,D2=dep[u2]+dep[v2]-2*dep[L2]+1;//找到长度
		int l=1,r=min(D1,D2),ans=0;
		while(l<=r) {
			const int mid=l+r>>1;
			const data _x=dep[u1]-mid+1>=dep[L1]?get(u1,kfa(u1,mid)):get(u1,L1)+get_(fa[L1],kfa(v1,D1-mid));//要么在LCA左边路径,要么左边全部和右边k级祖先反串
			const data _y=dep[u2]-mid+1>=dep[L2]?get(u2,kfa(u2,mid)):get(u2,L2)+get_(fa[L2],kfa(v2,D2-mid));
			if(_x==_y)l=(ans=mid)+1;
			else r=mid-1;
		}//即可
		printf("%d\n",ans);
	}
}