zcmimi's blog
查看原题

点击跳转

#include<bits/stdc++.h>
const int N=1000011,P=1000000007,G=5,Gi=400000003;
int fac[N],inv[N],ifac[N],B[N],C[N];
int pw(int x,int b){
    int res=1;
    while(b){
        if(b&1)res=1ll*res*x%P;
        b>>=1;x=1ll*x*x%P;
    }
    return res;
}
void mod(int&x){if(x>=P)x-=P;}
int L,r[N];
void getL(int n){
    for(L=1;L<n;L<<=1);
    for(int i=0;i<L;++i)
        r[i]=(r[i>>1]>>1)|((i&1)?(L>>1):0);
}
void swap(int&x,int&y){x^=y,y^=x,x^=y;}
void ntt(int*A,int typ){
    for(int i=0;i<L;++i)
        if(i<r[i])swap(A[i],A[r[i]]);
    for(int len=1;len<L;len<<=1){
        int Wn=pw(~typ?G:Gi,(P-1)/(len<<1));
        for(int i=0;i<L;i+=len<<1){
            int w=1;
            for(int k=0;k<len;++k){
                int t=1ll*w*A[i+k+len]%P;
                mod(A[i+k+len]=A[i+k]-t+P);
                mod(A[i+k]+=t);
                w=1ll*w*Wn%P;
            }
        }
    }
    if(~typ)return;
    for(int i=0,iL=pw(L,P-2);i<L;++i)
        A[i]=1ll*A[i]*iL%P;
}
void inv(int*A,int*B,int n){
    if(n==1)return B[0]=pw(A[0],P-2),void();
    inv(A,B,(n+1)>>1);
    getL(n<<1);
    for(int i=0;i<n;++i)C[i]=A[i];
    for(int i=n;i<L;++i)C[i]=0;
    ntt(C,1),ntt(B,1);
    for(int i=0;i<L;++i)
        B[i]=1ll*B[i]*(2-1ll*C[i]*B[i]%P+P)%P;
    ntt(B,-1);
    for(int i=n;i<L;++i)B[i]=0;
}
void prep(int n){
    fac[0]=fac[1]=inv[0]=inv[1]=ifac[0]=ifac[1]=B[0]=1;
    for(int i=2;i<=n;++i)
        fac[i]=1ll*fac[i-1]*i%P,
        inv[i]=1ll*inv[P%i]*(P-P/i)%P,
        ifac[i]=1ll*ifac[i-1]*inv[i]%P;

    inv(ifac+1,B,L);
    for(int i=2;i<=n;++i)B[i]=1ll*B[i]*fac[i]%P;
}
int S(int n,int k){
    int ans=0;
    for(int i=0;i<=k;++i)
        mod(ans+=1ll*fac[k]*ifac[i]%P*ifac[k+1-i]%P*B[i]%P*pw(n,k+1-i)%P);
    return ans;
}
int main(){
    int n,k;
    scanf("%d%d",&n,&k);
    getL(k+1),prep(L);
    printf("%d\n",S(n+1,k));        
}
LG 3711 仓鼠的数学题
comment评论
Search
search