头晕……
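所以即求 $\dfrac{e_m(a_1,\dots,a_n)}{\binom{n}{m}}\cdot m^{-\lfloor m/2\rfloor} \bmod 998244353$，
其中 $e_m(a_1,\dots,a_n)=[x^m]\prod_{i=1}^{n}(1+a_i x)$ 为 $a_1,\dots,a_n$ 的 $m$ 次初等对称多项式（所有 $m$ 元子集之积的和），用分治 NTT 求出。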
#include <cstdio>
#include <iostream>
#include <cmath>
using namespace std;
using ll = long long;
// Capacity of the global arrays (NTT work buffers, bit-reversal table,
// factorial tables; the polynomial pool uses N << 1). The padded NTT length
// must never exceed N.
constexpr int N = 1 << 21;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1 with primitive root G = 3.
constexpr int P = 998244353;
constexpr int G = 3, Gi = 332748118; // Gi = G^{-1} mod P
// Same modulus as P, used by the combinatorics code. Defined in terms of P so
// the two can never drift apart.
constexpr int mod = P;
static_assert(1LL * G * Gi % P == 1, "Gi must be the modular inverse of G");
// Fast modular exponentiation (square-and-multiply): returns a^b mod P.
// Intended for b >= 0. The result lies in [0, P) whenever a lies in [0, P).
int qmi(int a, int b)
{
    long long acc = 1;
    long long base = a;
    for (; b != 0; b >>= 1)
    {
        if (b & 1)
            acc = acc * base % P;
        base = base * base % P;
    }
    return static_cast<int>(acc);
}
// Bit-reversal permutation table; filled by Mul() before each transform.
int rev[N];
// In-place iterative number-theoretic transform of a[0..n) modulo P.
// n is expected to be a power of two matching the current contents of rev[].
// inv == 1 runs the forward transform (root G). Any other value uses the
// inverse root Gi; when inv == -1 the result is additionally scaled by n^{-1}.
void NTT(int *a, int n, int inv)
{
    // Reorder the input into bit-reversed index order.
    for (int i = 0; i < n; ++i)
    {
        if (i < rev[i])
            swap(a[i], a[rev[i]]);
    }
    // Butterfly passes: each pass merges blocks of size `half` into blocks of size `len`.
    for (int half = 1; half < n; half *= 2)
    {
        const int len = half * 2;
        const int root = qmi(inv == 1 ? G : Gi, (P - 1) / len);
        for (int start = 0; start < n; start += len)
        {
            long long w = 1;
            for (int k = start; k < start + half; ++k)
            {
                const int u = a[k];
                const int v = static_cast<int>(w * a[k + half] % P);
                a[k] = (u + v) % P;
                a[k + half] = (u - v + P) % P;
                w = w * root % P;
            }
        }
    }
    if (inv != -1)
        return;
    // The inverse transform must divide every coefficient by n.
    const long long n_inv = qmi(n, P - 2);
    for (int i = 0; i < n; ++i)
        a[i] = static_cast<int>(n_inv * a[i] % P);
}
//=======================================================
int fact[N], infact[N];
// Fills fact[i] = i! mod `mod` and infact[i] = (i!)^{-1} mod `mod` for every
// 0 <= i <= lim. lim defaults to 100000 and is clamped to [0, N - 1] so the
// tables are never overrun.
void init(int lim = 100000)
{
    if (lim >= N)
        lim = N - 1;
    if (lim < 0)
        lim = 0;
    fact[0] = 1;
    for (int i = 1; i <= lim; i++)
        fact[i] = 1ll * i * fact[i - 1] % mod;
    // Fermat inverse of lim!; assumes mod is prime (998244353).
    infact[lim] = qmi(fact[lim], mod - 2);
    // Walk down with (i!)^{-1} = ((i+1)!)^{-1} * (i+1). The loop includes
    // i = 0 so that infact[0] == 1 is actually stored.
    for (int i = lim - 1; i >= 0; i--)
        infact[i] = 1ll * infact[i + 1] * (i + 1) % mod;
}
int n, m, A[N], B[N], C[N];
int Mul(int *a, int *b, int n, int m, int *ans)
{
int bit = 0, num = 1;
while (num < n + m + 1)
num <<= 1, bit++;
for (int i = 0; i < num; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
for (int i = 0; i <= n; i++)
A[i] = a[i];
for (int i = 0; i <= m; i++)
B[i] = b[i];
for (int i = n + 1; i < num; i++)
A[i] = 0;
for (int i = m + 1; i < num; i++)
B[i] = 0;
NTT(A, num, 1);
NTT(B, num, 1);
for (int i = 0; i < num; i++)
C[i] = 1ll * A[i] * B[i] % mod;
NTT(C, num, -1);
for (int i = 0; i <= n + m; i++)
ans[i] = C[i];
return n + m;
}
int pool[N << 1], tot;
struct Node
{
int *p, n;
void init(int x)
{
this->n = 1;
p = pool + tot;
for (int i = 0; i <= n; i++)
p[i] = 0;
p[0] = 1;
p[1] = x;
tot += n + 1;
}
void mul(const Node &o)
{
n = Mul(p, o.p, n, o.n, p);
}
};
// Divide and conquer over positions [l, r]. Leaves read one value x from
// stdin (in left-to-right order) and become 1 + x*X. Inner nodes return the
// product of their two halves, computed in place in the left half's storage.
Node solve(int l, int r)
{
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        Node left = solve(l, mid);
        left.mul(solve(mid + 1, r));
        return left;
    }
    int value;
    cin >> value;
    Node leaf;
    leaf.init(value);
    return leaf;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
init();
cin >> n >> m;
ll invk = qmi(m, mod - 2);
invk = qmi(invk, m / 2);
Node res = solve(1, n);
printf("%lld", 1ll * res.p[m] * fact[m] % mod * infact[n] % mod * fact[n - m] % mod * invk % mod);
return 0;
}