这次算明白数了。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Modulus for every count in this program. 998244353 is prime, which is what
// lets powmod(x, mod - 2) act as a modular inverse (Fermat's little theorem).
const int mod = 998244353;

// Input string, 1-indexed (q[1..ql]). main() rewrites each character in place
// with its symbol id: '0'-'9' -> 0..9, 'A'-'Z' -> 10..35, 'a'-'z' -> 36..61.
char q[1000009];
ll ql;    // length of the input string
ll ans;   // accumulated answer, reduced modulo mod
ll inv2;  // modular inverse of 2, assigned at the start of main()

// NOTE(review): none of these are read by the active code of main(); they look
// like leftovers and are candidates for removal.
ll k, kc, a, xsum, s, hf;
ll s1[1000005], s2[1000005];

// pre[i][c]: occurrences of symbol c in q[1..i] (row 0 stays all zero).
ll pre[1000005][62];
// to[i]: sum over c of pre[i][c], i.e. the prefix length i.
// rp[i]: sum over c of C(pre[i][c], 2) modulo mod -- pairs of equal symbols in q[1..i].
ll rp[1000005], to[1000005];
// lst[p]: despite the name, the position of the NEXT occurrence of q[p] after p,
// or 0 if there is none. lst[0] is used as a scratch slot while it is built.
ll lst[1000005];
// glst[c]: most recent position of symbol c seen while building lst.
ll glst[100];

// Per ordered symbol pair (c, d) accumulators for the right-to-left sweep in main():
// suf[c][d] sums, over already-processed positions p holding c, the count of d in (p, ql];
// at each position p holding c that has a later c, res[c][d] grows by
// (count of d in (p, next c]) * suf[c][d].
ll suf[63][63], res[63][63];

// Returns a^b modulo mod by binary exponentiation.
// The base is reduced into [0, mod) first, so a negative a still produces a
// result in [0, mod) (C++ '%' keeps the sign of the dividend).
// b <= 0 yields 1; negative exponents are not supported.
int powmod(int a, int b)
{
    ll base = a % mod;
    if (base < 0)
        base += mod;
    ll result = 1;
    while (b > 0)
    {
        if (b & 1)
            result = result * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return static_cast<int>(result);
}

signed main()
{
    inv2=powmod(2,mod-2);
    scanf("%s", q + 1);
    ql = strlen(q + 1);

    for (int i = 1; i <= ql; i++)
    {
        if (q[i] <= '9')
            q[i] -= '0';
        else if (q[i] <= 'Z')
            q[i] = q[i] - 'A' + 10;
        else
            q[i] = q[i] - 'a' + 36;
    }
    for(int i=1;i<=ql;++i){
        lst[glst[q[i]]]=i,glst[q[i]]=i;
    }
    //for (int i = 1; i <= ql; i++)cout<<(int)q[i]<<" \n"[i==ql];
    for (int i = 1; i <= ql; i++)
    {
        pre[i][q[i]]++;
        for (int j = 0; j < 62; j++)
        {
            pre[i][j] += pre[i - 1][j];
            pre[i][j] %= mod;
            to[i] = (to[i] + pre[i][j]) % mod;
            rp[i] = (rp[i] + (pre[i][j] * (pre[i][j] - 1ll ) %mod *inv2 % mod)+mod) % mod;
        }
    }
    // for (int i=1;i<=ql;i++) cout<<(int)q[i]<<" ";
    ans = 0;
    for(int bb=ql;bb;--bb){
        int i=q[bb];
        for(int j=0;j<62;++j){
            if(j==i)continue;
            //if(i!=38||j!=39)continue;
            ll hf = to[bb]-pre[bb][i]-pre[bb][j], a = rp[bb]-pre[bb][i]*(pre[bb][i]-1)/2%mod-pre[bb][j]*(pre[bb][j]-1)/2%mod;
            hf=(hf%mod+mod)%mod;
            a=(a%mod+mod)%mod;
            //cerr<<bb<<"++"<<hf<<"::"<<a<<"::"<<(pre[lst[bb]][j]-pre[bb][j]+mod)%mod<<"---"<<suf[i][j]<<"?"<<ans<<"\n";
            if(lst[bb]){
                ll sz=(pre[lst[bb]][j]-pre[bb][j]+mod)%mod;
                res[i][j]=(res[i][j]+sz*suf[i][j]%mod)%mod;
                ans=(ans+(hf * (hf - 1) %mod*inv2%mod - a+mod)%mod*res[i][j]%mod)%mod;
                //ans = (ans+(hf * (hf - 1) %mod*inv2%mod - a+mod)%mod * res % mod)%mod;
            }
            suf[i][j]=(suf[i][j]+pre[ql][j]-pre[bb][j]+mod)%mod;
        }
    }
    /*for (int i = 0; i < 62; i++)
        for (int j = 0; j < 62; j++)
        {
            if (i == j)
                continue;
            if(pre[ql][i]<2||pre[ql][j]<2)continue;
            //if(i!=15||j!=36)continue;
            vector<pair<ll, ll>> v;
            ll cn = 0, lst = -1;
            for (int k = 1; k <= ql; ++k)
            {
                if (q[k] == j)
                    cn += 1;
                if (q[k] == i && lst != -1)
                    v.push_back({cn, lst});
                if (q[k] == i)
                    lst = k, cn = 0;
                // cerr<<lst<<"::"<<cn<<"\n";
            }
            if (lst != -1)
                v.push_back({cn, lst});

            ll sz = v.size();
            if (sz <= 1)
                continue;

            // for(auto [aa,bb]:v)cout<<aa<<"???"<<bb<<"\n";
            for (int i = 1; i <= sz; ++i)
                s1[i] = v[i - 1].first%mod;
            for (int i = 1; i <= sz; ++i)
                s1[i] = (s1[i] +s1[i - 1])%mod;
            for (int k = 1; k <= sz; ++k)
                s2[k] = k * v[k - 1].first%mod;
            for (int k = 1; k <= sz; ++k)
                s2[k] =(s2[k]+ s2[k - 1])%mod;
            ll sy = ((s2[sz] - s1[sz])%mod+mod)%mod, res = 0;
            // cerr<<s1[sz]<<"::"<<s2[sz]<<"??\n";
            for (int i = 1; i < sz; ++i)
                res =(res+ s1[i] * (s1[sz] - s1[i])%mod+mod)%mod;
            // cerr<<res<<"?\n";
            for (int ii = 1; ii < sz; ++ii)
            {
                auto &&[aa, bb] = v[ii - 1];
                auto &&[cc, dd] = v[ii];
                ll hf = to[bb]-pre[bb][i]-pre[bb][j], a = rp[bb]-pre[bb][i]*(pre[bb][i]-1)/2%mod-pre[bb][j]*(pre[bb][j]-1)/2%mod;
                hf=(hf%mod+mod)%mod;
                a=(a%mod+mod)%mod;
                for(int k=0;k<62;++k){
                    if(k!=i&&k!=j){
                        hf=(hf+pre[bb][k])%mod;
                        a=(a-(pre[bb][k]*(pre[bb][k]-1)/2%mod)+mod)%mod;
                    }
                }
                //cerr<<hf<<"\n";
                ans = (ans+(hf * (hf - 1) %mod*inv2%mod - a+mod)%mod * res % mod)%mod;
                res = (res-aa * sy%mod+mod)%mod;
                // cerr<<aa<<"::"<<sy<<"??\n";
                sy =(sy- (s1[sz] - s1[ii])%mod+mod)%mod;
                // cerr<<res<<"::"<<ii<<"---"<<sy<<"\n";
            }
        }*/
    cout << ans;
    return 0;
}
此文章已被阅读次数:正在加载...更新于