头晕。。。


所以即求 $$\frac{e_m(a_1,\dots,a_n)}{\binom{n}{m}}\cdot m^{-\lfloor m/2\rfloor}\bmod 998244353,$$

其中 $e_m(a_1,\dots,a_n)=[x^m]\prod_{i=1}^{n}(1+a_ix)$ 是 $a_1,\dots,a_n$ 的第 $m$ 个初等对称多项式。

#include <cstdio>
#include <iostream>
#include <cmath>
using namespace std;
using ll = long long;

// N bounds the length of every global buffer below.
// P = 119 * 2^23 + 1 is an NTT-friendly prime; G = 3 is its primitive root
// and Gi = 3^{-1} mod P is the root used for the inverse transform.
// `mod` is the same modulus, used by the combinatorics code.
constexpr int N = 1 << 21;
constexpr int P = 998244353;
constexpr int G = 3;
constexpr int Gi = 332748118;
constexpr int mod = P;

// Fast modular exponentiation: returns a^b mod P by repeated squaring.
// The exponent is consumed from its least significant bit; b == 0 yields 1.
int qmi(int a, int b)
{
    int result = 1;
    for (int base = a; b != 0; b >>= 1)
    {
        if (b & 1)
            result = static_cast<int>(static_cast<long long>(result) * base % P);
        base = static_cast<int>(static_cast<long long>(base) * base % P);
    }
    return result;
}
// Bit-reversal permutation table, filled by Mul() before each transform.
int rev[N];

// In-place iterative number-theoretic transform of a[0..n) over Z_P.
// n is expected to be a power of two (Mul always passes one), and rev[0..n)
// must already hold the matching bit-reversal permutation.
// inv == 1 selects the forward transform (root G); any other value uses the
// inverse root Gi, and inv == -1 additionally scales by n^{-1}, so that
// NTT(a, n, 1) followed by NTT(a, n, -1) restores a.
void NTT(int *a, int n, int inv)
{
    // Reorder into bit-reversed order; the idx < partner test swaps each pair once.
    for (int idx = 0; idx < n; ++idx)
    {
        const int partner = rev[idx];
        if (idx < partner)
            swap(a[idx], a[partner]);
    }

    const int root = (inv == 1) ? G : Gi;
    for (int half = 1; half < n; half <<= 1)
    {
        const int len = half << 1;
        // Root of unity of order len (or its inverse) for this stage.
        const int step = qmi(root, (P - 1) / len);
        for (int base = 0; base < n; base += len)
        {
            int tw = 1;
            for (int k = 0; k < half; ++k)
            {
                int &lo = a[base + k];
                int &hi = a[base + k + half];
                const int u = lo;
                const int v = static_cast<int>(1LL * tw * hi % P);
                lo = (u + v) % P;
                hi = (u - v + P) % P;
                tw = static_cast<int>(1LL * tw * step % P);
            }
        }
    }

    if (inv == -1)
    {
        const int scale = qmi(n, P - 2); // n^{-1} mod P
        for (int idx = 0; idx < n; ++idx)
            a[idx] = static_cast<int>(1LL * a[idx] * scale % P);
    }
}
//=======================================================
int fact[N], infact[N];

// Precompute factorials fact[0..limit] and inverse factorials
// infact[0..limit] modulo `mod`.
// limit defaults to 100000 (the bound this program always used) and is
// clamped into [0, N - 1] so the tables never overflow.
void init(int limit = 100000)
{
    if (limit > N - 1)
        limit = N - 1;
    if (limit < 0)
        limit = 0;

    fact[0] = 1;
    for (int i = 1; i <= limit; i++)
        fact[i] = 1ll * i * fact[i - 1] % mod;

    // One modular inverse at the top, then walk downward using
    // infact[i] = infact[i + 1] * (i + 1). The loop runs down to i = 0 so
    // infact[0] (= 1) is filled too; previously it was left at 0.
    infact[limit] = qmi(fact[limit], mod - 2);
    for (int i = limit - 1; i >= 0; i--)
        infact[i] = 1ll * infact[i + 1] * (i + 1) % mod;
}
int n, m, A[N], B[N], C[N];

// Multiply the polynomials a (degree n) and b (degree m) modulo `mod` via the
// NTT. Writes the n + m + 1 product coefficients to ans and returns the
// product's degree, n + m.
// ans may alias a or b: both inputs are copied into the scratch buffers A/B
// before anything is written to ans.
// Uses the global scratch arrays rev, A, B and C, so calls must not overlap.
// NOTE(review): the padded length num must not exceed N; this is unchecked.
int Mul(int *a, int *b, int n, int m, int *ans)
{
    // Smallest power of two num = 2^bit that holds n + m + 1 coefficients.
    int bit = 0, num = 1;
    while (num < n + m + 1)
        num <<= 1, bit++;

    // Bit-reversal table for length num. Entry 0 is always 0. Starting the
    // recurrence at i = 1 avoids the undefined shift by (bit - 1) == -1 that
    // the i = 0 iteration performed when num == 1 (n == m == 0).
    rev[0] = 0;
    for (int i = 1; i < num; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));

    // Zero-padded copies of both inputs.
    for (int i = 0; i <= n; i++)
        A[i] = a[i];
    for (int i = 0; i <= m; i++)
        B[i] = b[i];
    for (int i = n + 1; i < num; i++)
        A[i] = 0;
    for (int i = m + 1; i < num; i++)
        B[i] = 0;

    // Pointwise product in the transformed domain, then transform back.
    NTT(A, num, 1);
    NTT(B, num, 1);
    for (int i = 0; i < num; i++)
        C[i] = 1ll * A[i] * B[i] % mod;
    NTT(C, num, -1);

    for (int i = 0; i <= n + m; i++)
        ans[i] = C[i];
    return n + m;
}

// Arena holding all polynomial coefficients; `tot` is the next free slot.
// Storage handed out from the arena is never released.
int pool[N << 1], tot;

// A polynomial of degree n whose n + 1 coefficients live contiguously in
// `pool`, starting at p.
struct Node
{
    int *p, n;

    // Become the linear polynomial 1 + x*X, taking the next two arena slots.
    void init(int x)
    {
        int *slot = pool + tot;
        slot[0] = 1;
        slot[1] = x;
        p = slot;
        n = 1;
        tot += 2;
    }

    // Replace this polynomial by its product with o, written in place at p.
    // NOTE(review): the product needs n + o.n + 1 slots starting at p. solve()
    // provides them because a range's leaves are allocated contiguously (two
    // slots each). Multiplying arbitrary Nodes this way could overrun.
    void mul(const Node &o)
    {
        const int degree = Mul(p, o.p, n, o.n, p);
        n = degree;
    }
};
// Read the next (r - l + 1) integers a_i from stdin, in order, and return the
// product of the linear factors (1 + a_i * X), built by divide and conquer.
// The left half is read (and allocated) before the right half, and the
// result reuses the left half's arena storage.
Node solve(int l, int r)
{
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        Node lhs = solve(l, mid);
        Node rhs = solve(mid + 1, r);
        lhs.mul(rhs);
        return lhs;
    }

    int value;
    cin >> value;
    Node leaf;
    leaf.init(value);
    return leaf;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    init();
    cin >> n >> m;
    ll invk = qmi(m, mod - 2);
    invk = qmi(invk, m / 2);
    Node res = solve(1, n);
    printf("%lld", 1ll * res.p[m] * fact[m] % mod * infact[n] % mod * fact[n - m] % mod * invk % mod);
    return 0;
}
此文章已被阅读次数:正在加载...更新于