// Adapted from someone else's solution — I honestly couldn't work out the combinatorics myself. :)

#include <bits/stdc++.h>
/*
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
*/
using namespace std;

// Floating-point helpers shared by the template (unused by this particular solution).
constexpr double eps = 1e-10;                              // comparison tolerance
constexpr double pi = 3.1415926535897932384626433832795;   // π
constexpr double eln = 2.718281828459045235360287471352;   // e, the base of the natural log

// Competitive-programming shorthand macros.
// Every macro argument is parenthesized so that expression arguments expand correctly
// (e.g. lowbit(a + b) must mean lowbit of the sum, not a + (b & -a) + b-style garbage).
#define f(i, a, b) for (int i = (a); i <= (b); i++)   // inclusive loop i = a..b
#define scan(x) scanf("%d", &(x))
#define mp make_pair
#define pb push_back
#define lowbit(x) ((x) & (-(x)))                       // lowest set bit of x

#define fi first
#define se second
#define SZ(x) int((x).size())
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define summ(a) (accumulate(all(a), 0ll))               // sum of a container as long long

// Short type aliases used throughout the template.
using ull = unsigned long long;
using pii = std::pair<int, int>;
using vi = std::vector<int>;

using ll = long long;

/// Reads the next (optionally negative) integer of type T from stdin.
/// `sample` is only used to deduce T (read(1) -> int, read(1LL) -> long long).
/// Non-digit characters before the number are skipped; a '-' among them makes
/// the result negative. The character right after the last digit is consumed.
/// Returns 0 if stdin reaches EOF before any digit is found (instead of spinning forever).
template <class T> inline T read(const T sample) {
    (void)sample;
    T x = 0;
    int sign = 1;
    // Must be int, not char: getchar() returns EOF (-1) which a char cannot
    // distinguish from a real byte, and which would otherwise loop endlessly below.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * sign;
}

// Table capacity (indices up to N-1 = 10000008) and the prime modulus.
constexpr int N = 10000009, mod = 998244353;
// n, m, k: the three input values (main replaces m with its modular inverse).
// f: shared scratch table — holds the prime list while math::initmath runs,
//    and is then overwritten by the solver's per-index values.
// ans: the answer accumulated modulo mod.
int n, m, k, f[N], ans;
namespace math{
    // id[i]  = i^k mod `mod` (k is the global exponent read in main), filled by initmath.
    // pct    = number of primes found by the sieve; they are stored in the global f[1..pct].
    // inv[i] : during the sieve it is reused as a "composite" flag; afterwards it holds
    //          the modular inverse of i.
    int id[N],pct,inv[N];

    // Binary exponentiation: returns n^k mod `mod` for k >= 0.
    inline int qpow(int n,int k){
        int res=1;
        for(;k>0;k>>=1,n=1ll*n*n%mod)if(k&1)res=1ll*n*res%mod;
        return res;
    }

    // Reduces x from [0, 2*mod) into [0, mod) without a branch: subtract mod, then add it
    // back if the result went negative (x>>31 is all ones for a negative 32-bit int on the
    // usual arithmetic-shift implementations).
    inline void fmod(int &x){x-=mod;x+=x>>31&mod;}

    // Linear (Euler) sieve up to n. It computes id[i] = i^k for every i <= n, using one qpow
    // per prime and the complete multiplicativity of i -> i^k for composites. It then builds
    // the modular inverse table inv[1..n].
    // NOTE: overwrites the global f[] with the list of primes.
    void initmath(const int &n=N-1){
        id[1]=1;
        id[0]=(k==0);  // 0^k is 0 for k > 0, but 0^0 = 1 (E[X^0] must be 1)
        for(int i=2;i<=n;++i){
            if(!inv[i])f[++pct]=i,id[i]=qpow(i,k);
            // Bound-check j before touching f[j], and form the product in 64 bits.
            for(int j=1;j<=pct&&1ll*i*f[j]<=n;++j){
                inv[i*f[j]]=1,id[i*f[j]]=1ll*id[i]*id[f[j]]%mod;
                if(i%f[j]==0)break;  // each composite is marked once, by its smallest prime
            }
        }
        inv[1]=1;
        // inv[i] = -(mod/i) * inv[mod % i]  (mod is prime, so mod % i lies in [1, i))
        for(int i=2;i<=n;++i)inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
    }
}
using namespace math;

namespace solve1{
    // Case n <= k: evaluates the full sum directly,
    //   ans = sum_{i=0..n} C(n,i) * p^i * (1-p)^(n-i) * i^k   (mod 998244353),
    // where p is the global m (already inverted in main) and i^k is read from id[].
    // Prints the answer and terminates the program.
    void main(){
        // f[t] = (1-p)^t for t = 0..n; 1-p is represented as mod+1-m.
        f[0]=1;
        const int q=mod+1-m;
        for(int t=1;t<=n;++t)f[t]=1ll*q*f[t-1]%mod;

        int binom=1;  // C(n, i)
        int pw=1;     // p^i
        for(int i=0;i<=n;++i){
            long long term=1ll*binom*pw%mod;
            term=term*f[n-i]%mod;
            term=term*id[i]%mod;
            ans+=term;
            fmod(ans);
            // C(n, i+1) = C(n, i) * (n-i) / (i+1)
            binom=1ll*binom*(n-i)%mod*inv[i+1]%mod;
            pw=1ll*pw*m%mod;
        }
        printf("%d\n",ans);
        exit(0);
    }
}

namespace solve2{
    // Case n > k: only k+1 terms are summed. First f[] is filled backwards:
    //   f[k] = 1,
    //   f[i] = (1-p) * f[i+1] + (-p)^(k-i) * C(n-i-1, k-i)   for i = k-1 .. 0,
    // then ans = sum_{i=0..k} C(n,i) * p^i * i^k * f[i]   (mod 998244353).
    // Presumably this equals the same binomial moment solve1 computes, via a k-term
    // reformulation — NOTE(review): derivation not shown here; confirm against the editorial.
    // Prints the answer and terminates the program.
    void main(){
        f[k]=1;
        const int q=mod+1-m;  // 1 - p
        int steps=0;          // k - i
        int comb=1;           // C(n-i-1, k-i), grown by one factor per step
        int negp=1;           // (-p)^(k-i)
        for(int i=k-1;i>=0;--i){
            ++steps;
            comb=1ll*comb*(steps+n-k-1)%mod*inv[steps]%mod;
            negp=1ll*negp*(mod-m)%mod;
            f[i]=1ll*q*f[i+1]%mod;
            fmod(f[i]);
            f[i]+=1ll*negp*comb%mod;
            fmod(f[i]);
        }

        int binom=1;  // C(n, i)
        int pw=1;     // p^i
        for(int i=0;i<=k;++i){
            ans+=1ll*id[i]*binom%mod*pw%mod*f[i]%mod;
            fmod(ans);
            binom=1ll*binom*(n-i)%mod*inv[i+1]%mod;
            pw=1ll*pw*m%mod;
        }
        printf("%d\n",ans);
        exit(0);
    }
}
// Entry point: reads n, m, k and stores p = 1/m (mod 998244353) back into m.
// Precomputes i^k and modular inverses up to min(n, k) + 3, then dispatches:
// solve1 when n <= k (sum over i = 0..n), solve2 otherwise (sum over i = 0..k).
// Both solvers print the answer and exit.
int main()
{
    n=read(1);
    const int denom=read(1);
    m=qpow(denom,mod-2);  // inverse via Fermat's little theorem (mod is prime)
    k=read(1);
    initmath(min(n,k)+3);
    if(n>k){
        solve2::main();
    }else{
        solve1::main();
    }
}
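// (Blog page footer copied with the code — "This article has been read N times: loading... updated on" — not part of the program.)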