#include <cstdio>
#include<assert.h>

using namespace std;

// Shorthand for 64-bit intermediates; a product of two residues below `mod` fits in it.
using ll = long long;
// Modulus for all arithmetic in this file (998244353 = 119 * 2^23 + 1, a prime,
// which inv() relies on via Fermat's little theorem).
const int mod = 998244353;

/// Modular integer modulo MOD, stored as its canonical representative in [0, MOD).
///
/// @tparam MOD modulus. inv() computes a^(MOD-2) (Fermat), so MOD must be prime.
/// @tparam RT  value returned by rt(); intended to be a primitive root of MOD for NTT/FFT.
///
/// The long long constructor is intentionally implicit, so integer operands may
/// appear on either side of the arithmetic and comparison operators.
template <int MOD, int RT>
struct mint
{
    static const int mod = MOD;
    // NOTE(review): the constructors are not constexpr, so despite the specifier
    // rt() cannot be evaluated inside a constant expression.
    static constexpr mint rt() { return RT; } // primitive root for FFT
    int v; // canonical representative, always in [0, MOD)
    explicit operator int() const { return v; } // explicit -> don't silently convert to int
    mint() : v(0) {}
    mint(long long _v)
    {
        // Only pay for % when _v lies outside (-MOD, MOD); then lift negatives into range.
        v = int((-MOD < _v && _v < MOD) ? _v : _v % MOD);
        if (v < 0)
            v += MOD;
    }

    // Comparisons are hidden friends so the implicit conversion can apply to either
    // operand: `1 == x` now compiles just like `x == 1` (and like `1 != x` already did).
    friend bool operator==(const mint &a, const mint &b) { return a.v == b.v; }
    friend bool operator!=(const mint &a, const mint &b) { return !(a == b); }
    // Orders by representative (usable as a key in ordered containers); this is not
    // a numeric ordering in Z/MOD.
    friend bool operator<(const mint &a, const mint &b) { return a.v < b.v; }

    mint &operator+=(const mint &o)
    {
        if ((v += o.v) >= MOD)
            v -= MOD;
        return *this;
    }
    mint &operator-=(const mint &o)
    {
        if ((v -= o.v) < 0)
            v += MOD;
        return *this;
    }
    mint &operator*=(const mint &o)
    {
        // Widen before multiplying: the product of two residues can exceed INT_MAX.
        v = int(static_cast<long long>(v) * o.v % MOD);
        return *this;
    }
    mint &operator/=(const mint &o) { return (*this) *= inv(o); }

    /// Binary exponentiation: returns a^p. Requires p >= 0 (asserted).
    friend mint pow(mint a, long long p)
    {
        mint ans = 1;
        assert(p >= 0);
        for (; p; p /= 2, a *= a)
            if (p & 1)
                ans *= a;
        return ans;
    }
    /// Multiplicative inverse as a^(MOD-2); requires a != 0 (asserted) and MOD prime.
    friend mint inv(const mint &a)
    {
        assert(a.v != 0);
        return pow(a, MOD - 2);
    }

    mint operator-() const { return mint(-v); }
    mint &operator++() { return *this += 1; }
    mint &operator--() { return *this -= 1; }
    friend mint operator+(mint a, const mint &b) { return a += b; }
    friend mint operator-(mint a, const mint &b) { return a -= b; }
    friend mint operator*(mint a, const mint &b) { return a *= b; }
    friend mint operator/(mint a, const mint &b) { return a /= b; }
};

// Modular integer type modulo `mod` (998244353). The second template argument (3),
// returned by mi::rt(), is the standard primitive root of 998244353.
using mi = mint<mod, 3>;

// DP table filled in main via s[i][j] = s[i-1][j-1] + j * s[i-1][j]
// (Stirling numbers of the second kind S(i, j) mod `mod`).
// The initializer sets s[0][0] = 1; every other entry starts as mi() == 0.
// Its (int)5e3 + 9 rows/columns bound the largest k that can be handled.
mi s[(int)5e3 + 9][(int)5e3 + 9] = {1};
// The three input integers read by main: n is the falling-factorial base,
// m is inverted modulo `mod`, and k selects the row of `s` that is summed.
ll n, m, k;

// Falling factorial x * (x-1) * ... * (x-t+1) reduced modulo `mod`, in [0, mod).
// Returns 1 for t <= 0 (empty product). Each factor is reduced into [0, mod)
// before multiplying, so the running product stays below mod^2 (< 2^63) for any
// x -- including x >= mod or x < 0, where multiplying the raw factor could
// overflow signed 64-bit arithmetic or yield a negative result.
ll cal(ll x, ll t)
{
    ll res = 1;
    for (ll j = 0; j < t; ++j)
    {
        ll factor = (x - j) % mod;
        if (factor < 0)
            factor += mod;
        res = res * factor % mod;
    }
    return res;
}

int main()
{
    scanf("%lld%lld%lld", &n, &m, &k);
    for(int i=1;i<=k;++i)for(int j=1;j<=k;++j)s[i][j]=s[i-1][j-1]+s[i-1][j]*j;
    mi ans=0,invm=inv(mi(m));
    for(int i=0;i<=k;++i)ans+=s[k][i]*cal(n,i)*pow(invm,i);
    printf("%d",ans.v);
}
// 此文章已被阅读次数:正在加载...更新于 (stray web-page footer, "view count: loading... updated on"; not part of the program)