// Accepted, though only just within the time limit ^_^

// Note: the time limit is tight.

#include <bits/stdc++.h>
using namespace std;

// 64-bit integer shorthand for modular products.
using ll=long long;

// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
constexpr int mod=998244353;
// g = 3 is a primitive root of mod; gi is its inverse (3 * 332748118 == 1 mod p).
constexpr int g=3,gi=332748118;
// NR: capacity of the NTT work arrays (a power of two, at least 2 * maxn).
// maxn: capacity of the 1-based input buffers (including the terminating NUL).
constexpr int NR=1<<19,maxn=125009;

int n=1;               // current NTT length (power of two), chosen in main
int bit;               // log2(n), chosen in main
int a[NR],b[NR];       // NTT work buffers
int rev[NR];           // bit-reversal permutation for length n
int ls,lt;             // lengths of s and t
int sig;               // number of distinct letters in s and t
// ans[p][q][i]: true when, with t aligned so that it ends at 0-based index i
// of s, some position pairs letter 'a'+p in s with letter 'a'+q in t.
// Only 0/1 is ever stored, so bool (not int) keeps this table ~4x smaller.
bool ans[10][10][maxn];
bitset<300> vis;       // letters already recorded in sigma
char s[maxn],t[maxn];  // input strings, stored 1-based
vector<char> sigma;    // distinct letters of s then t, in first-occurrence order

// Disjoint-set union (union-find) over the elements 0..x, with path compression.
struct dsu{
    vector<int> fa,sz;  // fa: parent links; sz: component size (valid at roots)
    // Creates x+1 singleton sets {0}, {1}, ..., {x}.
    // Index 0 is initialised as well: callers pass 0-based ids (letter - 'a').
    dsu(int x):fa(x+1),sz(x+1,1){
        iota(fa.begin(),fa.end(),0);
    }
    // Returns the representative of x's set, compressing the path on the way.
    int fnd(int x){return x==fa[x]?x:fa[x]=fnd(fa[x]);}
    // Unites the sets of x and y; y's root is attached below x's root.
    // Returns true if the two were previously disjoint, false otherwise.
    bool merge(int x,int y){
        const int oa=fnd(x),ob=fnd(y);
        if(oa==ob)return false;
        fa[ob]=oa,sz[oa]+=sz[ob];
        return true;
    }
};

// Binary exponentiation: returns a^b modulo the global prime `mod`.
// Intended for 0 <= a < mod and b >= 0; any b <= 0 yields 1.
int powmod(int a,int b){
    long long base=a;
    long long result=1;
    while(b>0){
        if(b&1)
            result=result*base%mod;
        base=base*base%mod;
        b>>=1;
    }
    return (int)result;
}

// In-place iterative number-theoretic transform of length n (global) over Z/mod.
//
// type != 0: forward transform using root g.
// type == 0: inverse transform using root gi, followed by scaling by n^{-1}.
// Relies on the global bit-reversal table rev[] having been built for n (main
// does this) and expects every a[i] to already lie in [0, mod).
void NTT(int *a,int type){
    const bool forward=(type!=0);
    for(int i=0;i<n;++i)if(i<rev[i])swap(a[i],a[rev[i]]);
    // w[k] = gn^k for the current level. It is computed once per level instead of
    // being regenerated for every butterfly group j.
    vector<int> w(max(1,n>>1));
    for(int i=1;i<n;i<<=1){
        const int gn=powmod(forward?g:gi,(mod-1)/(i<<1));
        w[0]=1;
        for(int k=1;k<i;++k)w[k]=1ll*w[k-1]*gn%mod;
        for(int j=0;j<n;j+=(i<<1)){
            for(int k=0;k<i;++k){
                const int x=a[j+k];
                const int y=1ll*w[k]*a[i+j+k]%mod;
                // x, y < mod and 2*mod < INT_MAX, so a single conditional
                // add/subtract replaces the slower % in the butterfly.
                const int sum=x+y,diff=x-y;
                a[j+k]=sum>=mod?sum-mod:sum;
                a[i+j+k]=diff<0?diff+mod:diff;
            }
        }
    }
    if(forward)return;
    const int invn=powmod(n,mod-2);
    for(int i=0;i<n;++i)a[i]=1ll*a[i]*invn%mod;
}

// For every alignment of t inside s, records whether some position pairs
// letter aa in s with letter bb in t.
//
// A[p] = [s[p+1] == aa] and B = reversed [t[j] == bb] are multiplied with the
// NTT. Coefficient i of the product (lt-1 <= i < ls) counts positions j with
// t[j] == bb sitting over s == aa when t is aligned to end at 0-based index i
// of s. The result is stored in ans[aa-'a'][bb-'a'][i].
// Assumes n >= ls + lt (set in main) so the cyclic product does not wrap, and
// that aa, bb lie in 'a'..'j' because ans is 10x10 — TODO confirm vs. input.
void solve(char aa,char bb){
    fill(a,a+n,0);
    fill(b,b+n,0);
    for(int i=1;i<=ls;++i)a[i-1]=(s[i]==aa);
    // Reversing t turns the convolution into a cross-correlation.
    for(int j=1;j<=lt;++j)b[lt-j]=(t[j]==bb);
    NTT(a,1),NTT(b,1);
    for(int i=0;i<n;++i)a[i]=1ll*a[i]*b[i]%mod;
    NTT(a,0);
    // Each count is at most lt < mod, so a nonzero residue means a real match.
    for(int i=lt-1;i<ls;++i)ans[aa-'a'][bb-'a'][i]=(a[i]!=0);
}

int cal(int x){
    //cerr<<x<<"---\n";
    dsu d(30);
    int res=0;
    for(int i=0;i<sig;++i)for(int j=0;j<sig;++j)if(ans[sigma[i]-'a'][sigma[j]-'a'][x])res+=d.merge(sigma[i]-'a',sigma[j]-'a');
    return res;
}

// Reads s and t and prints one value per alignment of t inside s: the result of
// cal() for each window ending at 0-based index lt-1 .. ls-1 of s.
int main(){
    // Without two strings, ls == lt == 0 would make bit == 0 and the rev[] loop
    // below would shift by -1 (undefined behaviour), so bail out.
    if(scanf("%s%s",s+1,t+1)!=2)return 0;
    ls=strlen(s+1),lt=strlen(t+1);
    // Collects the distinct letters of a 1-based string in first-occurrence order.
    // vis is indexed through unsigned char: plain char may be signed, and a
    // negative index would wrap to a huge size_t in bitset::operator[].
    auto collect=[](const char *str,int len){
        for(int i=1;i<=len;++i){
            const unsigned char c=str[i];
            if(!vis[c])sigma.push_back(str[i]),vis[c]=1;
        }
    };
    collect(s,ls);
    collect(t,lt);
    sig=static_cast<int>(sigma.size());
    // Smallest power of two n >= ls + lt, so the product in solve() never wraps.
    while(n<ls+lt)n<<=1,bit+=1;
    // Bit-reversal permutation for the iterative NTT of length n.
    for(int i=0;i<n;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    // Pairs of equal letters never need unifying, so they are skipped.
    for(int i=0;i<sig;++i)
        for(int j=0;j<sig;++j)
            if(i!=j)solve(sigma[i],sigma[j]);
    for(int i=lt-1;i<ls;++i)printf("%d%c",cal(i)," \n"[i==ls-1]);
    return 0;
}
// (Footer residue from the original blog page — view counter and "updated on" stamp; not part of the program.)