CF504E Misha and LCP on Tree

IOI2020集训队作业
- 给定一棵 个节点的树,每个节点有一个小写字母。
- 有 组询问,每组询问为树上 和 组成的字符串的最长公共前缀。
- ,。
树上字符串QAQ先考虑序列上怎么做吧
相当于区间最长公共前缀,可以二分哈希值来做
放到树上好像就是怎么快速得到一段路径的哈希值
这是一个trick呢,维护每个节点到根的正串(root->x)hash值和反串(x->root)hash值
那么从x->y路径hash值就是先找到LCA,然后dep[x]-dep[LCA]这么长的反串hash值截取出来,然后把dep[y]-dep[LCA]这么长的正串hash值截取出来,二者拼接一下就行了
不过截取是实现不了的,所以我们考虑用二分时的树上k级祖先来代替
比如二分到前L位置,那么先判断L在x->LCA还是LCA->y上,然后一个k级祖先把它找到拿k级祖先处的字符串来比较就可以了
看代码吧qwq
code:
#include<bits/stdc++.h>
#define lg2(x)(31-__builtin_clz(x))
using namespace std;
#define ll long long
const int MAXN=3e5+7;
const int md1=1004535809,md2=167772161;
int base1,base2;
int _1[MAXN],_2[MAXN],n,m;
int i1[MAXN],i2[MAXN];
char s[MAXN];
int home[MAXN],ccnt;
int mxd[MAXN],toph[MAXN],tail[MAXN],sonh[MAXN],fa[MAXN],len[MAXN],F[20][MAXN],dep[MAXN],idx,dfn[MAXN];
int *up[MAXN],*down[MAXN];
int st[20][MAXN<<1];
inline int inv(int a,const int &P) {
ll ans=1,b=P-2;
while(b) {
if(b&1)ans=ans*a%P;
a=1ll*a*a%P;
b>>=1;
}
return ans;
}
struct edge {
int to,nxt;
} e[MAXN<<1];
struct data {
int s1,s2;
int len;
inline void push_front(char c) {
s1=(s1+1ll*c*_1[len])%md1,s2=(s2+1ll*c*_2[len++])%md2;
}
inline void push_back(char c) {
s1=(1ll*s1*base1+c)%md1,s2=(1ll*s2*base2+c)%md2,++len;
}
inline data operator+(const data &rhs)const {
return (data) {
(1ll*s1*_1[rhs.len]+rhs.s1)%md1,(1ll*s2*_2[rhs.len]+rhs.s2)%md2,len+rhs.len
};
}
inline bool operator==(const data &rhs)const {
return s1==rhs.s1&&s2==rhs.s2&&len==rhs.len;
}
} a[MAXN],b[MAXN];
inline void dfs(int nw) {
st[0][dfn[nw]=++idx]=nw;
sonh[nw]=0,mxd[nw]=dep[nw];//mxd就是maxdep
a[nw]=a[fa[nw]],b[nw]=b[fa[nw]];
a[nw].push_back(s[nw]),b[nw].push_front(s[nw]);
for(int i=home[nw]; i; i=e[i].nxt) {
int v=e[i].to;
if(!dep[v]) {
dep[v]=dep[nw]+1;
fa[v]=F[0][v]=nw;
dfs(v);
st[0][++idx]=nw;
if(mxd[v]>mxd[nw])mxd[nw]=mxd[v],sonh[nw]=v;
}
}
}
inline void dfs2(int nw) {
tail[toph[nw]]=nw;
len[toph[nw]]=dep[nw]-dep[toph[nw]]+1;
if(sonh[nw])toph[sonh[nw]]=toph[nw],dfs2(sonh[nw]);
for(int i=home[nw]; i; i=e[i].nxt) {
int v=e[i].to;
if(dep[v]>dep[nw]&&v!=sonh[nw])dfs2(toph[v]=v);
//标准长链剖分
}
}
inline int kfa(int x,int k) {
if(dep[x]<=k)return 0;
if(k==0)return x;
const int lg=lg2(k);
x=F[lg][x],k-=1<<lg;
if(!k)return x;
const int dlt=dep[x]-dep[toph[x]];
if(dlt>=k)return down[toph[x]][dlt-k];
return up[toph[x]][k-dlt];
}
inline int LCA(int x,int y) {
if(dfn[x]>dfn[y])swap(x,y);
x=dfn[x],y=dfn[y];
int lg=lg2(y-x+1);
int a=st[lg][x],b=st[lg][y-(1<<lg)+1];
return (dep[a]<dep[b])?a:b;
}
inline void init() {
for(int i=1; i<19; ++i) {
for(int j=1; j<=n; ++j) {
F[i][j]=F[i-1][F[i-1][j]];//倍增处理
}
}
for(int i=1; i<=n; ++i) {
if(toph[i]==i) {
int h=len[i];
up[i]=new int[h+2];
down[i]=new int [h+2];//动态开空间处理
*up[i]=*down[i]=i;
for(int j=1,nw=fa[i]; j<=h&&nw; ++j,nw=fa[nw]) {
up[i][j]=nw;
}
for(int j=1,nw=sonh[i]; j<=h&&nw; ++j,nw=sonh[nw]) {
down[i][j]=nw;
}
}
}
for(int i=1; i<20; ++i) {
for(int j=1; j<=2*n; ++j) {
int x=st[i-1][j],y=st[i-1][j+(1<<i-1)];
st[i][j]=(dep[x]<dep[y])?x:y;//st±í¿ì°¡
}
}
}
inline data get(int u,int v) {
const data &x=b[u],&y=b[v];
return (data) {
1ll*(x.s1-y.s1+md1)*i1[y.len]%md1,1ll*(x.s2-y.s2+md2)*i2[y.len]%md2,x.len-y.len//暴力截取
};
}
inline data get_(int u,int v) {
const data&x=a[u],&y=a[v];
return (data) {
(y.s1-1ll*x.s1*_1[y.len-x.len]%md1+md1)%md1,(y.s2-1ll*x.s2*_2[y.len-x.len]%md2+md2)%md2,y.len-x.len//暴力截取2
};
}
int main() {
srand(time(0));
base1=rand()%100+200,base2=rand()%300+400;
scanf("%d%s",&n,s+1);
_1[0]=_2[0]=i1[0]=i2[0]=1;
for(int i=1; i<=n; ++i)_1[i]=1ll*_1[i-1]*base1%md1,_2[i]=1ll*_2[i-1]*base2%md2;
i1[1]=inv(base1,md1);
i2[1]=inv(base2,md2);
for(int i=2; i<=n; ++i) {
i1[i]=1ll*i1[i-1]*i1[1]%md1,i2[i]=1ll*i2[i-1]*i2[1]%md2;
}
for(int i=1; i<n; ++i) {
int u,v;
scanf("%d%d",&u,&v);
e[++ccnt]=(edge) {v,home[u]},home[u]=ccnt;
e[++ccnt]=(edge) {u,home[v]},home[v]=ccnt;
}
dfs(dep[1]=1),dfs2(toph[1]=1);
init();
//for(int i=1;i<=n;++i)printf("%d %d %d\n",dep[i],toph[i],fa[i]);
//puts("qwq");
for(scanf("%d",&m); m--;) {
int u1,v1,u2,v2;
scanf("%d%d%d%d",&u1,&v1,&u2,&v2);
int L1=LCA(u1,v1),L2=LCA(u2,v2);
// printf("%d %d\n",L1,L2);
const int D1=dep[u1]+dep[v1]-2*dep[L1]+1,D2=dep[u2]+dep[v2]-2*dep[L2]+1;//找到长度
int l=1,r=min(D1,D2),ans=0;
while(l<=r) {
const int mid=l+r>>1;
const data _x=dep[u1]-mid+1>=dep[L1]?get(u1,kfa(u1,mid)):get(u1,L1)+get_(fa[L1],kfa(v1,D1-mid));//要么在LCA左边路径,要么左边全部和右边k级祖先反串
const data _y=dep[u2]-mid+1>=dep[L2]?get(u2,kfa(u2,mid)):get(u2,L2)+get_(fa[L2],kfa(v2,D2-mid));
if(_x==_y)l=(ans=mid)+1;
else r=mid-1;
}//即可
printf("%d\n",ans);
}
}