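// Author's note: I solved this one during the summer multi-university training contest, but could not get it accepted tonight. Failed!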

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Capacity, in elements, of the bit-reversal table and of the NTT scratch buffers declared below.
const int NR = 1 << 22;
// G is the root NTT uses for the forward transform; Gi is its inverse modulo 998244353
// (3 * 332748118 = 998244354 = mod + 1) and is used for the inverse transform.
const int G = 3, Gi = 332748118;
// Prime modulus for all modular arithmetic in this file; MOD is a second name for the same value.
const int mod = 998244353, MOD = mod;
// Shorthand macros. Every macro argument is parenthesised and every expression
// macro is wrapped in outer parentheses, so that operator precedence at the use
// site cannot change the meaning (e.g. 100 / sqr(5) is 100 / 25, not 100 / 5 * 5).

// Inclusive counting loop: declares int i and runs it from a up to and including b.
// Note: because this is a function-like macro, any identifier `f` that is
// immediately followed by `(` will be expanded as this loop.
#define f(i, a, b) for (int i = (a); i <= (b); i++)
#define LL long long
// Redirect stdin / stdout to local files.
#define IN freopen("in.txt", "r", stdin)
#define OUT freopen("out.txt", "w", stdout)
// Read one int into the lvalue x.
#define scan(x) scanf("%d", &(x))
#define mp make_pair
#define pb push_back
// Square of x; x is evaluated twice, so do not pass expressions with side effects.
#define sqr(x) ((x) * (x))
// "Case N" headers in the two common formats, with and without a trailing newline.
#define pr1(x) printf("Case %d: ", x)
#define pn1(x) printf("Case %d:\n", x)
#define pr2(x) printf("Case #%d: ", x)
#define pn2(x) printf("Case #%d:\n", x)
// Lowest set bit of x; x is evaluated twice.
#define lowbit(x) ((x) & (-(x)))

#define fi first
#define se second
// Container size as int (two spellings are kept because both are used).
#define sz(x) int((x).size())
// Iterator pairs covering the whole container, forwards and backwards.
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
// Sum of all elements, accumulated in long long.
#define summ(a) (accumulate(all(a), 0ll))
#define SZ(x) ((int)(x).size())
/**
 * Reads the next decimal integer from stdin with getchar().
 *
 * Leading non-digit characters are skipped. A '-' counts as a sign only when
 * it is the character immediately before the first digit. The character that
 * terminates the number is consumed.
 *
 * @param sample  Unused value; it only selects the result type T.
 * @return The parsed value, or 0 if end of input is reached before any digit.
 */
template <class T>
inline T read(const T sample)
{
    (void)sample;
    T value = 0;
    bool negative = false;
    // getchar() returns int so that EOF is distinguishable from every real
    // character; storing it in a char makes the skip loop spin forever at EOF.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        negative = (ch == '-');
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + static_cast<T>(ch - '0');
        ch = getchar();
    }
    return negative ? static_cast<T>(-value) : value;
}

/**
 * Modular exponentiation by repeated squaring.
 *
 * @param a  Base; any value, negative ones included (it is reduced into [0, mod)).
 * @param b  Exponent; must be non-negative (asserted).
 * @return a^b modulo `mod`, always in the range [0, mod).
 */
ll qpow(ll a, ll b)
{
    assert(b >= 0);
    a %= mod;
    // In C++ the remainder takes the sign of the dividend, so a negative base
    // would otherwise stay negative and leak a negative result to the caller.
    if (a < 0)
        a += mod;
    ll res = 1;
    for (; b; b >>= 1)
    {
        if (b & 1)
            res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
// Euclid's algorithm: keep replacing (a, b) by (b, a % b) until b becomes 0, then a is the answer.
ll gcd(ll a, ll b)
{
    while (b != 0)
    {
        const ll rest = a % b;
        a = b;
        b = rest;
    }
    return a;
}

/**
 * Integer modulo the compile-time constant MOD.
 *
 * The stored residue `v` is always kept in [0, MOD). RT is carried along as
 * the root returned by rt(). Construction, rt() and the conversion to int are
 * constexpr, so rt() really is usable in constant expressions.
 *
 * Assumptions visible in the code: operator+= adds two residues in an int, so
 * 2 * MOD - 2 must fit in int; inv() uses the exponent MOD - 2, which yields
 * the inverse only when MOD is prime.
 */
template <int MOD, int RT>
struct mint
{
    static const int mod = MOD;
    // Maps a value already in (-MOD, MOD) to the canonical range [0, MOD).
    static constexpr int normalize(long long r) { return int(r < 0 ? r + MOD : r); }
    static constexpr mint rt() { return mint(RT); } // primitive root for FFT
    int v;
    constexpr explicit operator int() const { return v; } // explicit -> don't silently convert to int
    constexpr mint() : v(0) {}
    // Accepts any long long; the % is skipped when the value is already in (-MOD, MOD).
    constexpr mint(long long _v)
        : v(normalize((-MOD < _v && _v < MOD) ? _v : _v % MOD))
    {
    }
    constexpr bool operator==(const mint &o) const
    {
        return v == o.v;
    }
    friend constexpr bool operator!=(const mint &a, const mint &b)
    {
        return !(a == b);
    }
    // Orders by stored residue; useful only for putting values in ordered containers.
    friend constexpr bool operator<(const mint &a, const mint &b)
    {
        return a.v < b.v;
    }

    mint &operator+=(const mint &o)
    {
        if ((v += o.v) >= MOD)
            v -= MOD;
        return *this;
    }
    mint &operator-=(const mint &o)
    {
        if ((v -= o.v) < 0)
            v += MOD;
        return *this;
    }
    mint &operator*=(const mint &o)
    {
        // Widen before multiplying: the product of two residues does not fit in int.
        v = int((long long)v * o.v % MOD);
        return *this;
    }
    // Division multiplies by the inverse; asserts (inside inv) that o is non-zero.
    mint &operator/=(const mint &o) { return (*this) *= inv(o); }
    // a^p by repeated squaring; p must be non-negative (asserted).
    friend mint pow(mint a, long long p)
    {
        mint ans = 1;
        assert(p >= 0);
        for (; p; p /= 2, a *= a)
            if (p & 1)
                ans *= a;
        return ans;
    }
    // Multiplicative inverse computed as a^(MOD - 2); a must be non-zero (asserted).
    friend mint inv(const mint &a)
    {
        assert(a.v != 0);
        return pow(a, MOD - 2);
    }

    mint operator-() const { return mint(-v); }
    mint &operator++() { return *this += 1; }
    mint &operator--() { return *this -= 1; }
    friend mint operator+(mint a, const mint &b) { return a += b; }
    friend mint operator-(mint a, const mint &b) { return a -= b; }
    friend mint operator*(mint a, const mint &b) { return a *= b; }
    friend mint operator/(mint a, const mint &b) { return a /= b; }
};

// Modular integer type used by the rest of the file: residues modulo `mod`, with 3 passed as the RT parameter.
using mi = mint<mod, 3>;
// Lazily grown tables of factorials, inverse factorials and modular inverses.
// Every accessor first calls check(), which doubles the tables until the
// requested index exists.
namespace simp
{
    vector<mi> fac, ifac, invn;

    // Ensures fac/ifac/invn all have an entry at index x (for x >= 0).
    void check(int x)
    {
        if (fac.empty())
        {
            // Seed entries for indices 0 and 1.
            fac.assign(2, mi(1));
            ifac.assign(2, mi(1));
            invn.assign(2, mi(1));
            invn[0] = mi(0);
        }
        while (static_cast<int>(fac.size()) <= x)
        {
            const int old_size = static_cast<int>(fac.size());
            const int new_size = old_size * 2;
            fac.resize(new_size);
            ifac.resize(new_size);
            invn.resize(new_size);
            for (int k = old_size; k < new_size; k++)
            {
                fac[k] = fac[k - 1] * mi(k);
                // Inverse of k derived from the already known inverse of MOD % k.
                invn[k] = mi(MOD - MOD / k) * invn[MOD % k];
                ifac[k] = ifac[k - 1] * invn[k];
            }
        }
    }

    // x!
    mi gfac(int x)
    {
        check(x);
        return fac[x];
    }

    // Modular inverse of x.
    mi ginv(int x)
    {
        check(x);
        return invn[x];
    }

    // Inverse of x!
    mi gifac(int x)
    {
        check(x);
        return ifac[x];
    }

    // Binomial coefficient C(n, m); 0 when m lies outside [0, n].
    mi binom(int n, int m)
    {
        if (0 <= m && m <= n)
            return gfac(n) * gifac(m) * gifac(n - m);
        return mi(0);
    }
}

// Bit-reversal permutation table. NTT reads rev[0..n-1]; Mul fills it before each transform.
int rev[NR];

/**
 * In-place number-theoretic transform of a[0..n-1].
 *
 * @param a     Coefficient array, overwritten with the result.
 * @param n     Transform length; rev[0..n-1] must already hold the matching
 *              bit-reversal permutation.
 * @param type  1 selects the forward transform (root G). Any other value
 *              selects the inverse transform (root Gi) and scales the result
 *              by the inverse of n.
 */
void NTT(mi *a, int n, int type)
{
    // Reorder into bit-reversed order, swapping each pair exactly once.
    for (int idx = 0; idx < n; ++idx)
    {
        const int partner = rev[idx];
        if (idx < partner)
            swap(a[idx], a[partner]);
    }

    const mi root = (type == 1) ? mi(G) : mi(Gi);
    for (int half = 1; half < n; half <<= 1)
    {
        const int span = half << 1;
        // Root of unity of order `span` for this level of butterflies.
        const mi step = pow(root, (mod - 1) / span);
        for (int start = 0; start < n; start += span)
        {
            mi *lo = a + start;
            mi *hi = lo + half;
            mi twiddle = mi(1);
            for (int j = 0; j < half; ++j)
            {
                const mi even = lo[j];
                const mi odd = twiddle * hi[j];
                lo[j] = even + odd;
                hi[j] = even - odd;
                twiddle = twiddle * step;
            }
        }
    }

    if (type != 1)
    {
        // The inverse transform additionally divides every entry by n.
        const mi scale = inv(mi(n));
        for (int idx = 0; idx < n; ++idx)
            a[idx] = a[idx] * scale;
    }
}

//=======================================================
// File-scope n, m, q: no code visible in this file reads them (Mul's parameters and main's local n shadow n and m).
int n, m, q;
// Scratch buffers used by Mul: A and B receive the zero-padded inputs and are transformed in place;
// C receives their pointwise product and is transformed back.
mi A[NR], B[NR], C[NR];

/**
 * Multiplies two polynomials via NTT.
 *
 * @param a    Coefficients a[0..n] of the first factor.
 * @param b    Coefficients b[0..m] of the second factor.
 * @param n    Index of the last coefficient of a (must be >= 0).
 * @param m    Index of the last coefficient of b (must be >= 0).
 * @param ans  Receives the n + m + 1 product coefficients. It may alias a or b,
 *             because both inputs are copied into the scratch buffers A and B
 *             before anything is written to ans.
 * @return n + m, the index of the last coefficient written to ans.
 *
 * Overwrites the global tables rev, A, B and C.
 */
int Mul(mi *a, mi *b, int n, int m, mi *ans)
{
    assert(n >= 0 && m >= 0);
    // Smallest power of two that can hold all n + m + 1 product coefficients.
    int bit = 0, num = 1;
    while (num < n + m + 1)
        num <<= 1, ++bit;
    // rev, A, B and C all have exactly NR entries.
    assert(num <= NR);

    // rev[0] is set explicitly and the loop starts at 1: for num == 1 we have
    // bit == 0, and evaluating `<< (bit - 1)` would be a negative shift.
    rev[0] = 0;
    for (int i = 1; i < num; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));

    for (int i = 0; i <= n; ++i)
        A[i] = a[i];
    for (int i = 0; i <= m; ++i)
        B[i] = b[i];
    // Zero-pad both inputs up to the transform length.
    for (int i = n + 1; i < num; ++i)
        A[i] = 0;
    for (int i = m + 1; i < num; ++i)
        B[i] = 0;

    NTT(A, num, 1);
    NTT(B, num, 1);
    for (int i = 0; i < num; ++i)
        C[i] = A[i] * B[i];
    NTT(C, num, 0);

    for (int i = 0; i <= n + m; ++i)
        ans[i] = C[i];
    return n + m;
}

struct Node
{
    static mi PP[NR << 1];
    static int tot;

    mi *p;
    int n;
    void init(int x)
    {
        this->n = 1;
        p = PP + tot;
        p[0] = mi(1);
        p[1] = mi(x);
        tot += n + 1;
    }
    void mul(const Node &o)
    {
        n = Mul(p, o.p, n, o.n, p);
    }
};
mi Node::PP[NR << 1] = {mi(0)};
int Node::tot = 0;

// a[]: input values for cdq (a[l] seeds the leaf for index l); fa[]: parent links read and rewritten by fnd.
// sz[], cnt[] and tol are not read by any code visible in this file (main only resets tol to 0).
int a[(int)5e6 + 9], fa[(int)5e6 + 9], sz[(int)5e6 + 9], cnt[(int)5e6 + 9], tol = 0;
/**
 * Divide-and-conquer product over the index range [l, r].
 *
 * Each leaf i is created with T::init(a[i]); the two halves of a range are
 * combined with T::mul. The left half is always evaluated completely before
 * the right half.
 *
 * @param l   First index (inclusive); requires l <= r.
 * @param r   Last index (inclusive).
 * @param aa  Unused; it only lets callers select T by passing an object.
 * @return The combined T for a[l..r].
 */
template <class T>
T cdq(int l, int r, T aa)
{
    (void)aa;
    // An empty range would otherwise recurse forever (mid never changes).
    assert(l <= r);
    // Value-initialise so the placeholder copies handed to the recursive calls
    // never read indeterminate members.
    T ans{};
    if (l == r)
    {
        ans.init(a[l]);
        return ans;
    }
    const int mid = l + (r - l) / 2;
    ans = cdq(l, mid, T{});
    ans.mul(cdq(mid + 1, r, T{}));
    return ans;
}

// Returns the representative of x in the parent array fa and points every node on the
// walked path directly at it. Iterative on purpose: the recursive form needs one stack
// frame per link, and a long parent chain could exhaust the stack.
int fnd(int x)
{
    int root = x;
    while (fa[root] != root)
        root = fa[root];
    // Second pass: re-hang every node met on the way straight under the root.
    while (fa[x] != root)
    {
        const int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}

// Linear sieve up to x producing:
//   pr[1..cnt]  the primes in increasing order (1-based),
//   vis[i]      set for 1 and for every composite i <= x (clear for primes),
//   mu[i]       the Moebius function,
//   sum_mu[i]   mu[1] + ... + mu[i].
namespace sta
{
    const int maxn = 1e6 + 9;
    long long pr[maxn], mu[maxn], sum_mu[maxn], cnt;
    std::bitset<maxn> vis;

    // Fills the tables for 1..x. Requires 0 <= x < maxn. Safe to call more than once.
    void init(int x = 1e6)
    {
        // Every table has maxn entries and is indexed up to x.
        assert(x >= 0 && x < maxn);
        // Restart the prime list so a repeated call does not append duplicates.
        cnt = 0;
        mu[1] = 1, vis[1] = 1;
        for (long long i = 2; i <= x; ++i)
        {
            if (!vis[i])
                pr[++cnt] = i, mu[i] = -1;
            for (int j = 1; j <= cnt && i * pr[j] <= x; ++j)
            {
                vis[pr[j] * i] = 1;
                // pr[j] divides i: i * pr[j] has a squared prime factor, so its mu
                // stays at the zero it was initialised with, and larger primes
                // will reach their multiples of i through another factorisation.
                if (i % pr[j] == 0)
                    break;
                mu[i * pr[j]] = -mu[i];
            }
        }
        for (int i = 1; i <= x; ++i)
            sum_mu[i] = sum_mu[i - 1] + mu[i];
    }
}

ll ct,cn[(int)1e6+9],b[(int)6022],qsb;
int main()
{
    ll n;
    Node::tot = 0;
    tol = 0;
    sta::init();

    n = read(1);
    n*=2;
    std::vector<mi> fac(n + 1), invfac(n + 1);
    fac[0] = 1;
    for (int i = 1; i <= n; i++) {
        fac[i] = fac[i - 1] * i;
    }
    invfac[n] = inv(fac[n]);
    for (int i = n; i; i--) {
        invfac[i - 1] = invfac[i] * i;
    }
    f(i, 1, n) {
        b[i] = read(1);
        cn[b[i]]+=1;
        if(!sta::vis[b[i]]&&cn[b[i]]==1){
            ct+=1;
        }
    }
    if(ct<n/2){
        cout<<"0\n";
        return 0;
    }

    mi base=1;
    f(i,1,1e6){
        if(!sta::vis[i])a[++qsb]=cn[i];
        base*=invfac[cn[i]];
    }
    Node ans;
    ans = cdq(1, ct, ans);
    mi res = ans.p[n/2]*base*fac[n/2];
    cout << res.v << "\n";
}
// (Blog page footer) Number of times this article has been read: loading... Last updated on