题目大意
一行有\(n\)个球,现在将这些球分成\(k\) 组,每组可以有一个球或相邻两个球。一个球只能在至多一个组中(可以不在任何组中)。求对于\(1\leq k\leq m\)的所有\(k\)分别有多少种分组方法。
答案对\(998244353\)取模。
\(n\leq {10}^9,m<2^{19}\)
题解
因为\(k>n\)的项都是\(0\),所以我们钦定\(m\leq n\)
考虑DP。
记\(f_{i,j}\)为前\(i\)个球分为\(j\)组的方案数。
\[ f_{i,j}=f_{i-1,j}+f_{i-1,j-1}+f_{i-2,j-1} \] 直接做是\(O(nm)\)的。如果把\(f_i\)看成一个多项式,即
\[ F_i(x)=\sum_{j\geq 0}f_{i,j}x^j \] 那么转移就变成了\[ F_i(x)=(1+x)F_{i-1}(x)+xF_{i-2}(x) \] 这是一个常系数齐次线性递推,用FFT优化可以做到\(O(m\log m\log n)\)考虑怎么求一个常系数齐次线性递推关系的通项公式。
先求出这个转移矩阵的特征多项式:
\[ \lambda^2-(1+x)\lambda-x \] 特征值为\[ \begin{align} \lambda_1&=\frac{1+x+\sqrt{x^2+6x+1}}{2}\\ \lambda_2&=\frac{1+x-\sqrt{x^2+6x+1}}{2} \end{align} \] 我们钦定\(F_{-1}(x)=0\),设\[ F_i(x)=c_1{\lambda_1}^{i+1}+c_2{\lambda_2}^{i+1} \] 带入\(F_{-1}(x),F_0(x)\)得\[ \begin{cases} c_1&+c2&=0\\ c_2\lambda_1&+c_2\lambda_2&=1 \end{cases} \] 解得\[ \begin{cases} c_1&=\frac{1}{\lambda_1-\lambda_2}\\ c_2&=\frac{1}{\lambda_2-\lambda_1} \end{cases} \] 于是\[ F_i(x)=\frac{ {\lambda_1}^{i+1}-{\lambda_2}^{i+1}}{\lambda_1-\lambda_2} \] 直接用多项式开根求出\(\lambda_1,\lambda_2\),然后用多项式\(\ln \exp\)求出\(F_n(x)\)时间复杂度:\(O(m\log m)\)
小优化:因为\([x^0]\lambda_2=0\),所以\(\lambda_2\)的前\(n+1\)项都是\(0\)。因为\(m\leq n\),所以
\[ {\lambda_2}^{n+1}\equiv 0 \pmod {x^{m+1}} \] 所以我们不用计算\({\lambda_2}^{n+1}\)了。代码
#include#include #include using namespace std;typedef long long ll;const int maxn=300000;const ll p=998244353;const ll g=3;const ll inv2=(p+1)/2;ll fp(ll a,ll b){ ll s=1; for(;b;b>>=1,a=a*a%p) if(b&1) s=s*a%p; return s;}int rev[maxn];ll w1[maxn];ll w2[maxn];void ntt(ll *a,int n,int t){ int i,j,k; ll u,v,w,wn; for(i=1;i >1]>>1)|(i&1?n>>1:0); if(rev[i]>i) swap(a[i],a[rev[i]]); } for(i=2;i<=n;i<<=1) { wn=(t==1?w1[i]:w2[i]); for(j=0;j >1); static ll a1[maxn],a2[maxn]; int i; for(i=0;i >1;i++) a2[i]=b[i]; for(;i <<1;i++) a2[i]=0; ntt(a1,n<<1,1); ntt(a2,n<<1,1); for(i=0;i <<1;i++) a1[i]=a2[i]*((2-a1[i]*a2[i])%p)%p; ntt(a1,n<<1,-1); for(i=0;i >1); static ll a1[maxn],a2[maxn],a3[maxn]; int i; for(i=0;i >1;i++) a1[i]=b[i]; for(;i <<1;i++) a1[i]=0; for(i=0;i >1); int i; for(i=n>>1;i >1;i++) { a2[i]=b[i]; a3[i]=a[i+(n>>1)]-a1[i+(n>>1)]; } for(;i >1;i++) b[i+(n>>1)]=a2[i];}ll a[maxn];ll b[maxn];ll c[maxn];ll d[maxn];ll e[maxn];int n,m;int k;void solve(){ k=1; while(k<=m) k<<=1; int i; for(i=2;i<=k<<1;i++) { w1[i]=fp(g,(p-1)/i); w2[i]=fp(w1[i],p-2); } d[0]=1; d[1]=6; d[2]=1; getsqrt(d,c,k); for(i=0;i n) { x=m-n; m=n; } solve(); while(x--) printf("0 "); return 0;}