// 搞了半天，原来是 NTT 写挂了：里面有个变量写错了。(It took ages to find: the NTT was broken because a variable was mistyped.)

#include <bits/stdc++.h>
/*
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
*/
using namespace std;

// Numeric constants (unused in this file).
const double eps = 1e-10;
const double pi = 3.1415926535897932384626433832795;
const double eln = 2.718281828459045235360287471352;
// Every `int` written after this line is really `long long` (hence `signed main`).
#define int long long
// Inclusive ascending loop: i runs from a to b.
#define f(i, a, b) for (int i = (a); i <= (b); i++)
// `int` expands to `long long`, so the conversion must be %lld, not %d.
#define scan(x) scanf("%lld", &(x))
#define mp make_pair
#define pb push_back
// Lowest set bit of x; the argument is parenthesized so e.g. lowbit(a + b) works.
#define lowbit(x) ((x) & (-(x)))

#define fi first
#define se second
// static_cast instead of a functional cast: `int(...)` would expand to the
// ill-formed `long long(...)`.
#define SZ(x) static_cast<int>((x).size())
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define summ(a) (accumulate(all(a), 0ll))

typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef vector<int> vi;

using ll=long long;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1 with root g = 3;
// gi = 3^{-1} mod 998244353. N bounds every global array below.
const int mod=998244353,g=3,gi=332748118,N=1e5+9;
// Test count, per-test input (n, k, x, values a[1..n]),
// factorial / inverse-factorial tables, and the DP array.
int tt,n,x,k,fac[N],inv[N],a[N],dp[N];

// NTT scratch: bit-reversal table, two work buffers, log2 of the size, and the size.
int rev[N],a0[N],a1[N],bit,lim;
// Square-and-multiply modular exponentiation: returns a^b mod `mod`.
// Any b <= 0 yields 1. Presumably a is already reduced into [0, mod); a much
// larger a could overflow the first squaring.
int powmod(int a,int b){
    int result = 1;
    int base = a;
    while (b > 0) {
        if (b & 1) {
            result = 1ll * result * base % mod;
        }
        base = 1ll * base * base % mod;
        b >>= 1;
    }
    return result;
}

// Fills fac[i] = i! and inv[i] = (i!)^{-1} modulo `mod` for 0 <= i <= n.
// Only inv[n] needs an exponentiation: Fermat's little theorem gives
// fac[n]^(mod-2) because mod is prime. The rest are filled downward via
// (i!)^{-1} = ((i+1)!)^{-1} * (i+1).
// n must be below the array size N; this is not checked here.
void init(int n){
    fac[0] = 1;
    inv[0] = 1;
    if (n <= 0) return;  // only 0! is required; also keeps inv[n] in bounds
    for (int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;
    inv[n] = powmod(fac[n], mod - 2);
    // `i > 0` (not merely `i != 0`) so the walk can never run below index 1.
    for (int i = n - 1; i > 0; --i) inv[i] = 1ll * inv[i + 1] * (i + 1) % mod;
}

void NTT(int *aa,int type){
    for(int i=0;i<lim;++i)if(i<rev[i])swap(aa[i],aa[rev[i]]);
    for (int k = 1; k < lim; k <<= 1){
        int wn = powmod((type == 1 ? g : gi), (mod - 1) / (k << 1));
        for (int i = 0; i < lim; i += (k << 1)){
            int w=1;
            for(int j=0;j<k;++j,w=1ll*w*wn%mod){
                ll x=aa[i+j],y=1ll*w * aa[i + j + k]%mod;
                aa[i + j] = (x + y)%mod;
                aa[i + j + k] = ((x - y)%mod+mod)%mod;
            }
        }
    }
    if (type == 1)
        return;
    int invn = powmod(lim,mod-2);
    for (int i = 0; i < lim; ++i)
        aa[i] = 1ll*aa[i] * invn%mod;
}

// Handles one test case: reads n, k, x and then n values, and prints dp[k].
// dp[j] (j = 0..k) is maintained as the sum, over every non-empty subset S of
// the values read so far, of x^{|S|} * (sum of S)^j, mod `mod`.
// Adding a value v extends each subset S to S ∪ {v}. By the binomial theorem,
//   (s + v)^j = j! * sum_t (s^t / t!) * (v^{j-t} / (j-t)!),
// so the update is an exponential-generating-function product done with NTT:
//   dp'[j] = dp[j] + x*v^j + x * j! * (EGF(dp) * EGF(v^.))[j].
// fac/inv must already be filled up to at least k (by init()). lim must fit in
// the global arrays.
void solve(){
    cin >> n >> k >> x;
    x = (x % mod + mod) % mod;  // canonical multiplier: keeps products in range and output non-negative
    for (int j = 0; j <= k; ++j) dp[j] = 0;

    // Smallest power of two strictly greater than 2k. The product of two degree-k
    // polynomials has degree 2k, so the cyclic convolution does not wrap around.
    lim = 1, bit = 0;
    while (lim <= 2 * k) lim <<= 1, bit += 1;
    // Bit-reversal table. rev[0] is always 0. Starting at i = 1 avoids the
    // negative shift (bit - 1 == -1) when k == 0 gives lim == 1.
    rev[0] = 0;
    for (int i = 1; i < lim; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));

    for (int i = 1; i <= n; ++i) {
        cin >> a[i];
        const int v = (a[i] % mod + mod) % mod;  // canonical residue of the new value

        // a0 = EGF of the current dp, a1 = EGF of v^j, both truncated at degree k.
        int pw = 1;  // v^j
        for (int j = 0; j <= k; ++j) {
            a0[j] = 1ll * dp[j] * inv[j] % mod;
            a1[j] = 1ll * pw * inv[j] % mod;
            pw = 1ll * pw * v % mod;
        }
        // Zero-pad exactly the remaining transform positions k+1 .. lim-1.
        for (int j = k + 1; j < lim; ++j) a0[j] = a1[j] = 0;

        NTT(a0, 1), NTT(a1, 1);
        for (int t = 0; t < lim; ++t) a0[t] = 1ll * a0[t] * a1[t] % mod;
        NTT(a0, 0);

        pw = 1;
        for (int j = 0; j <= k; ++j) {
            // The singleton {v} contributes x*v^j. Each earlier subset S contributes
            // x * j! * conv[j] = x * sum_t C(j,t) * dp[t] * v^{j-t}.
            dp[j] = (dp[j] + 1ll * pw * x % mod + 1ll * a0[j] * x % mod * fac[j] % mod) % mod;
            pw = 1ll * pw * v % mod;
        }
    }
    cout << dp[k] << "\n";
}
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    init(1e4+100);
    cin>>tt;
    f(sb,1,tt)solve();
    return 0;
}
// 此文章已被阅读次数：正在加载... 更新于 (leftover blog-page footer: "read count: loading... updated on")