P3824 [NOI2017]泳池

NOI2017D1T3
强行学了常系数齐次线性递推来做此题QAQ
首先你会发现难点在于怎么入手写这个DP
这个二维的矩阵概率问题我还真没做过其他类型题.......TAT
第一步是恰好的处理方式,有容斥和差分两种
容斥就是直接容斥,差分就是变为极大子矩阵大小小于等于k的概率-极大子矩阵大小小于等于k概率
由于不是方案数所以不太容斥,考虑差分吧
又一想,我们之前做的概率树那道题不就是有个DP状态是大于等于啥啥啥吗?而且那个题还也是个概率期望的,所以DP,而且这个可能写进状态里
设表示宽度为i,底部第i个点恰好为坏点,并且大小的矩阵中不存在大于k的极大子矩阵概率
答案就是因为去掉最后一行硬点的就是前n随便放答案
那么的转移可以枚举底部坏点,
那么好像就是没有硬点的限制下矩阵中没有大于k的极大子矩阵概率
这个显然不好算,考虑转化为一堆数的和,也就是我们加1维能让他转移
宽为i的矩阵,坏点高度最小是那么这个矩阵中不会出现面积大于k的极大子矩阵概率
他的转移是先从左到右枚举坏点位置,再从低到高枚举第一个坏点高度?来作为划分点转移,因为高度*宽度>k的没有意义,所以能够保证的是不会出现面积大于k
注意其中一个时j+1到inf,一个时j到inf,其实是防止算重,因为注意我们枚举的是第一个
坏点的位置,也就是说前面要有坏点必须要高度比他大!所以这个才只能是j+1
这个显然可以后缀和优化,设
那么转移可以改写为
这样其实就是,p就有了此时目标是,令人沮丧的是,你发现
常系数齐次线性递推!!
其实做这个题的模板并不是一件令人开心的事情,调了半天的Rev数组/////
首先可以常系数齐次线性递推的也一定满足这样的形式
其中
那么这个形式的就可以了
因为矩阵乘法慢其实是它记载的信息太多了导致的,我们由之前压缩自动机也能看出,假如我们得到了一个序列,A是转移矩阵,满足
我们考虑实际上我们只需要知道中某一项的值
所以同乘,并用分配率
具体某一项是那一项呢?好像是第一项啊
但是不就对应着a数组递推i次后的值吗?
前k项是知道的呀......所以只要用转移矩阵搞出c序列就够了
构造序列C
嗯,是n次的,两边次数不等,c可能不存在啊
所以构造一下$$A^{n}=Q(A)G(A)+R(A)$$
其次数等于n,次数小于n
然后我们就开始钦定......因为不难发现有多组解吧
并且我们硬点G非常神奇,他能够等于0,换句话说
那么我们就有了
这也和之前的定义满足呢
那么我们怎么求?最简单的一步
$
?标准多项式快速幂并加上多项式取模
,这一步决定了我考场不可能写他
现在我们只需要构造出,然而我们至此还没有用到常系数齐次线性递推
一点性质
所以假设递推系数为
证明不会,做完了泳池还是只需暴力的,优化没意义
code:
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
using namespace std;
const int MAXN=2048;
const ll P=998244353;
int n,k;
ll p,q,x,y;
inline ll ksm(ll x,ll y) {
ll ans=1;
while(y) {
if(y&1)ans=ans*x%P;
y>>=1;
x=x*x%P;
}
return ans;
}
ll sdp[MAXN][MAXN];
ll st[MAXN];
ll ret[MAXN];
ll tr[MAXN];
ll f[MAXN];
ll cp1[MAXN];
ll cp2[MAXN];
ll a[MAXN];
inline ll solve(int k) {
for(int i=0; i<=k+1; ++i)sdp[i][0]=1;
for(int j=k; j>=1; --j) {
for(int i=1; i*j<=k; ++i) {
//i枚举宽度
//j枚举高度
ll ret=0;
for(int t=1; t<=i; ++t) {
(ret+=sdp[j+1][t-1]*sdp[j][i-t])%=P;
//组合起来,注意其中要空出一列
}
ret=ret*p%P*ksm(q,j)%P;
//ret是维护dp用的,最后乘上p,q^j表示在i+1号列放一个,然后其他的那一列都是安全的
sdp[j][i]=(sdp[j+1][i]+ret)%P;
}
}
++k;
tr[1]=p;
//开始常系数齐次线性递推,我不懂了
for(int i=1; i<=k-1; ++i)tr[i+1]=sdp[1][i]*p%P;
st[0]=1;
for(int i=1; i<k; ++i)for(int j=0; j<i; ++j)(st[i]+=st[j]*tr[i-j])%=P;
for(int i=1; i<=k; ++i)f[k-i]=P-tr[i];
f[k]=1;
ret[0]=1;
a[1]=1;
int t=n+1;
//ret是要返回的ans
//cp1和cp2是tmp数组
//a是ksm累乘
while(t) {
if(t&1) {
for(int i=0; i<=k; ++i)cp1[i]=ret[i],ret[i]=0;
for(int i=0; i<=k; ++i)for(int j=0; j<=k; ++j)(ret[i+j]+=cp1[i]*a[j])%=P;
for(int i=2*k; i>=k; --i) {
for(int j=0; j<=k; ++j) {
(ret[i-k+j]+=(P-ret[i]*f[j]%P))%=P;
}
}
//懒人表现,暴力展开....k^2
}
for(int i=0; i<=k; ++i)cp1[i]=a[i];
for(int i=0; i<=k; ++i)cp2[i]=a[i],a[i]=0;
for(int i=0; i<=k; ++i)for(int j=0; j<=k; ++j)(a[i+j]+=cp1[i]*cp2[j])%=P;
for(int i=2*k; i>=k; --i)
for(int j=0; j<=k; ++j)(a[i-k+j]+=P-a[i]*f[j]%P)%=P;
t>>=1;
}
ll ans=0;
for(int i=0; i<k; ++i)(ans+=st[i]*ret[i])%=P;
for(int i=0; i<=k+1; ++i)for(int j=0; j<=k+1; ++j)sdp[i][j]=0;
for(int i=0; i<=k; ++i)a[i]=0;
for(int i=0; i<=k; ++i)ret[i]=0;
for(int i=0; i<=k; ++i)st[i]=0;
return ans*ksm(p,P-2)%P;//最后除以p
}
int main() {
scanf("%d%d%lld%lld",&n,&k,&x,&y);
q=x*ksm(y,P-2)%P;
p=(1+P-q)%P;
printf("%lld",(solve(k)+P-solve(k-1))%P);
return 0;
}
模板,常系数齐次线性递推代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int MAXN=2e5+7;
#define ll long long
const int G0=3;
const int G1=332748118;
const ll P=998244353;
int n,k,R[MAXN];
ll rt[20][20];
int len;
ll tr1[MAXN],tr2[MAXN];
ll st[MAXN],xs[MAXN];
ll sg[MAXN],a[MAXN],res[MAXN],irg[MAXN],q[MAXN],rf[MAXN];
int L=-1;
ll ans=0;
ll ret[MAXN],D[MAXN];
inline ll ksm(ll x,ll y) {
ll ans=1;
while(y) {
if(y&1)ans=ans*x%P;
x=x*x%P;
y>>=1;
}
return ans;
}
inline void calc(int L) {
// printf("%d??\n",L);
for(int i=1; i<(1<<L); i++) {
R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
}
return ;
}
inline void NTT(ll *F,int len,int typ) {
register int i,mid,k,j;
for( i=0; i<len; ++i)if(i<R[i])swap(F[i],F[R[i]]);
// if(typ==3) {
// for(int i=0; i<len; ++i)printf("%d ",F[i]);
// puts("");
// }
for( mid=1; mid<len; mid<<=1) {
ll wn=ksm(typ==-1?G1:G0,(P-1)/(mid<<1));
for(j=0; j<len; j+=(mid<<1)) {
ll w=1;
for( k=0; k<mid; k++,w=w*wn%P) {
ll x=F[j+k];
ll y=F[j+k+mid]*w%P;
F[j+k]=(x+y)%P;
F[j+k+mid]=(x-y+P)%P;
}
}
}
if(typ==-1) {
ll iv=ksm(len,P-2);
for(int i=0; i<len; ++i)(F[i]=F[i]*iv%P);
}
return ;
}
inline void F_inv(ll *A,ll *B,int deg) {
if(deg==1) {
B[0]=ksm(A[0],P-2);
return ;
}
F_inv(A,B,(deg+1)>>1);
int L=0,len=1;
while(len<(deg<<1))len<<=1,++L;
for(int i=0; i<deg; ++i)D[i]=A[i];
for(int i=deg; i<len; ++i)D[i]=0;
calc(L);
// printf("%d %d %d\n\n",len,deg,L);
NTT(B,len,1);
NTT(D,len,1);
// for(int i=0; i<len; ++i)printf("%lld ?",B[i]);
// puts("");
for(int i=0; i<len; ++i)B[i]=1ll*(2-1ll*B[i]*D[i]%P+P)%P*B[i]%P;
// for(int i=0; i<len; ++i)printf("%lld ?",B[i]);
// puts("");
// for(int i=0; i<len; ++i)printf("%lld ?",D[i]);
// puts("\n");
NTT(B,len,-1);
for(int i=deg; i<len; ++i)B[i]=0;
return ;
}
inline void poly_mod(ll *A) {
int mi=(k<<1);
while(A[--mi]==0);
if(mi<k)return ;
for(int i=0; i<(len); ++i)rf[i]=0;
for(int i=0; i<=mi; ++i)rf[i]=A[i];
reverse(rf,rf+mi+1);
for(int i=mi-k+1; i<=mi; ++i)rf[i]=0;
NTT(rf,(len),1);
for(int i=0; i<(len); ++i)q[i]=(rf[i]*irg[i])%P;
NTT(q,(len),-1);
for(int i=mi-k+1; i<=(len); ++i)q[i]=0;
reverse(q,q+mi-k+1);
NTT(q,(len),1);
for(int i=0; i<(len); ++i)(q[i]=q[i]*sg[i]%P);
NTT(q,(len),-1);
for(int i=0; i<k; ++i)A[i]=(A[i]+P-q[i])%P;
for(int i=k; i<=mi; ++i)A[i]=0;
}
int main() {
scanf("%d%d",&n,&k);
len=1;
while(len<(k<<1))len<<=1,++L;
for(int i=1; i<=k; ++i) {
scanf("%lld",&xs[i]);
xs[i]=xs[i]<0?xs[i]+P:xs[i];//��...
}
for(int i=0; i<k; ++i) {
scanf("%lld",&st[i]);
st[i]=st[i]<0?st[i]+P:st[i];//��...
}
for(int i=1; i<=k; ++i)sg[k-i]=xs[i];//��,f�ǵ���ϵ��,�����ͷ�ת��-��
sg[k]=1;
for(int i=0; i<=k; ++i)ret[i]=sg[i];
for(int i=0; i<=k; ++i)rf[i]=sg[i];
reverse(rf,rf+k+1);
F_inv(rf,irg,len);
// for(int i=0; i<len; ++i)printf("%lld?\n",irg[i]);
for(int i=0; i<=k; ++i)rf[i]=0;
// printf("%d?\n",len);
len<<=1;
++L;
NTT(sg,len,1);
NTT(irg,len,1);
// for(int i=0; i<len; ++i)printf("%lld %lld\n",sg[i],irg[i]);
a[1]=1;
res[0]=1;
calc(L+1);
//for(int i=0; i<len; ++i)printf("%d? ",R[i]);
// printf("#%d\n",len);
while(n) {
if(n&1) {
// puts("Yes");
NTT(res,len,1);
NTT(a,len,1);
for(int i=0; i<len; ++i)res[i]=(res[i]*a[i])%P;
NTT(res,len,-1);
NTT(a,len,-1);
// for(int i=0; i<len; ++i)printf("%lld$ ",a[i]);
// puts("");
poly_mod(res);
}
NTT(a,len,1);
for(int i=0; i<len; ++i)a[i]=a[i]*a[i]%P;
NTT(a,len,-1);
// for(int i=0; i<len; ++i)printf("%lld? ",a[i]);
// puts("");
poly_mod(a);
n>>=1;
}
for(int i=0; i<k; ++i)(ans+=res[i]*st[i])%=P;//,printf("%lld %lld\n",st[i],res[i]);
printf("%lld\n",ans);
return 0;
}
因为2合1最长,所以这个要放好图
最后罕见的吐槽一下katex的好用并且我不需要typora2333