正题
题目链接:https://jzoj.net/senior/#contest/show/3017/2
题目大意
求有多少个长度为nnn的序列使得
- 都是在集合SSS中的数
- 这些数的乘积%m=x\% m=x%m=x
解题思路
设fi,jf_{i,j}fi,j表示长度为iii的序列,乘积为jjj的有多少个,显然有
fi,j∗w%m=fi−1,j(w∈S)f_{i,j*w\%m}=f_{i-1,j}(w\in S)fi,j∗w%m=fi−1,j(w∈S)
然后有
f2∗i,j=∑a∗b%m=jfi,a∗fi,bf_{2*i,j}=\sum_{a*b\%m=j}f_{i,a}*f_{i,b}f2∗i,j=a∗b%m=j∑fi,a∗fi,b
此时我们可以用矩阵乘法做到O(m3logn)O(m^3\log n)O(m3logn)
但是此复杂度显然无法胜任本题
因为mmm是质数,所以对于1∼m−11\sim m-11∼m−1都可以用一个gi%mg^i\% mgi%m表示出来,我们枚举找出一个ggg后,就有
fi,jf_{i,j}fi,j表示长度为iii的序列,乘积为gj%mg_j\%mgj%m的有多少个
那么就有f2∗i,j=∑(a+b)%m=jfi,a∗fi,bf_{2*i,j}=\sum_{(a+b)\%m=j}f_{i,a}*f_{i,b}f2∗i,j=(a+b)%m=j∑fi,a∗fi,b
这很显然是一个卷积的形式,所以我们表示出一个多项式后用NTTNTTNTT做快速幂即可。
时间复杂度:O(mlogmlogn):O(m\log m\log n):O(mlogmlogn)
codecodecode
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=5e4+10,XJQ=1004535809;
ll n,m,z,s,d[N],cnt,len,invn;
ll f[N],ans[N],tmp1[N],tmp2[N],r[N];
bool v[N];
ll power(ll x,ll b,ll p){ll ans=1;while(b){if(b&1)ans=ans*x%p;x=x*x%p;b>>=1;}return ans;
}
ll FindRoot(){ll l=m-1;for(ll i=2;i*i<=l;i++){if(l%i==0){d[++cnt]=i;while(l%i==0)l/=i;}}if(l!=1) d[++cnt]=l;l=m-1;for(ll i=2;i<=l;i++){bool flag=1;for(ll j=1;j<=cnt;j++)if(power(i,l/d[j],m)==1){flag=0;break;}if(flag) return i;}return 0;
}
void NTT(ll *x,ll op){for(ll i=0;i<len;i++)if(i<r[i])swap(x[i],x[r[i]]);for(ll p=2;p<=len;p<<=1){ll l=p>>1,tmp=power(3,(XJQ-1)/p,XJQ);if(op==-1)tmp=power(tmp,XJQ-2,XJQ);for(ll k=0;k<len;k+=p){ll buf=1;for(ll i=k;i<k+l;i++){ll tt=buf*x[l+i]%XJQ;x[l+i]=(x[i]-tt+XJQ)%XJQ;x[i]=(x[i]+tt)%XJQ;buf=buf*tmp%XJQ;}}}if(op==-1)for(int i=0;i<len;i++)x[i]=x[i]*invn%XJQ;
}
void mul(ll *a,ll *b){for(ll i=0;i<len;i++)tmp1[i]=a[i],tmp2[i]=b[i];NTT(tmp1,1);NTT(tmp2,1);for(ll i=0;i<len;i++)tmp1[i]=tmp1[i]*tmp2[i]%XJQ;NTT(tmp1,-1);for(ll i=0;i<m-1;i++)a[i]=(tmp1[i]+tmp1[i+m-1])%XJQ;
}
int main()
{ll z0;scanf("%lld%lld%lld%lld",&n,&m,&z0,&s);ll root=FindRoot();for(ll i=1;i<=s;i++){ll x;scanf("%lld",&x);v[x]=1;}for(ll i=0,x=1;i<m-1;i++,x=x*root%m){if(v[x])f[i]=1;if(x==z0)z=i;}for(len=1;len<=(m-1)<<1;len<<=1);for(ll i=0;i<len;i++)r[i]=(r[i>>1]>>1)|((i&1)?len>>1:0);ans[0]=1;invn=power(len,XJQ-2,XJQ);while(n){if(n&1)mul(ans,f);mul(f,f);n>>=1;}printf("%lld",ans[z]);
}