DSU on tree (keep the heavy child's data, re-insert the light subtrees) combined with a BIT (Fenwick tree).

#include <bits/stdc++.h>
/*
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
*/
using namespace std;

// General-purpose numeric constants: a floating-point comparison tolerance,
// pi, and Euler's number e.
constexpr double eps{1e-10};
constexpr double pi{3.1415926535897932384626433832795};
constexpr double eln{2.718281828459045235360287471352};

// Competitive-programming shorthand macros. Every macro argument is wrapped in
// parentheses in the expansion so that expression arguments such as
// `lowbit(i + 1)` or `f(i, 0, n - 1)` keep their intended precedence.
// (The loop variable of `f` is a declarator and cannot be parenthesized.)
#define f(i, a, b) for (int i = (a); i <= (b); i++)
#define scan(x) scanf("%d", &(x))
#define mp make_pair
#define pb push_back
// Lowest set bit of x, e.g. lowbit(12) == 4, lowbit(1 + 1) == 2.
#define lowbit(x) ((x) & (-(x)))

#define fi first
#define se second
#define SZ(x) int((x).size())
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define summ(a) (accumulate(all(a), 0ll))

// Short type aliases used throughout the file.
using ull = unsigned long long;
using pii = pair<int, int>;
using vi = vector<int>;

using ll = long long;

/// Fenwick tree (binary indexed tree) over positions 1..n supporting point
/// updates and prefix-sum queries, each walking O(log n) nodes.
/// @tparam T accumulated value type; must support `+=` and construction from 0.
template<typename T>
struct BIT{
    int n;             // largest valid position; declared first because it is initialized first
    std::vector<T> h;  // 1-indexed partial sums; h[0] is never used

    /// Creates an all-zero tree over positions 1..sz.
    explicit BIT(int sz) : n(sz), h(static_cast<std::size_t>(sz) + 1, T(0)) {}

    /// Adds `val` at position x. Positions above n are ignored.
    /// Precondition: x >= 1 — x <= 0 would loop forever since (0 & -0) == 0.
    void update(int x, T val) {
        assert(x > 0 && "BIT positions are 1-based");
        for (; x <= n; x += x & -x) h[x] += val;
    }

    /// Returns the sum over positions 1..x. x <= 0 yields 0; x > n is clamped
    /// to n (returning the total) instead of reading past the end of h.
    T query(int x) const {
        T res = 0;
        for (x = std::min(x, n); x > 0; x -= x & -x) res += h[x];
        return res;
    }
};

// Global solver state (static storage, so zero-initialized):
//   n        - number of tree nodes
//   a[v]     - value attached to node v (1-indexed)
//   dfn_cnt  - counter handing out DFS entry indices
//   presum   - sum of the values currently inserted into the BITs
//   ans      - running result for the values currently inserted
//   res[v]   - result recorded for node v
ull n;
ull a[500009];
ull dfn_cnt;
ull presum;
ull ans;
ull res[500009];

// Fenwick trees indexed by node value (positions 1..1000001):
// `bit` counts inserted values per position, `bit_sq` sums their squares.
BIT<int> bit(1000001);
BIT<ull> bit_sq(1000001);

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin>>n;
    vector g(n+1,vector<int>());
    for(int i=1,x,y;i<n;++i){
        cin>>x>>y;
        g[x].push_back(y),g[y].push_back(x);
    }
    for(int i=1;i<=n;++i)cin>>a[i];

    vector<int> son(n+1,0), sz(n+1,0), dfn(n+1,0), ori(n+1,0);  
    auto dfs=[&](const auto& self, const vector<vector<int>>&g, int u, int fa)->void{
        sz[u]=1, dfn[u]=++dfn_cnt, ori[dfn[u]]=u;

        for(auto it:g[u])if(it!=fa){
            self(self, g, it, u);
            sz[u]+=sz[it];
            son[u]=sz[son[u]]<sz[it]?it:son[u];
        }
    };

    auto dsu=[&](const auto& self, const vector<vector<int>>&g, int u, int fa,int sta)->void{

        auto cal=[&](ull cor)->void{
            ull f = bit.query(cor);
            ull e = bit_sq.query(int(1e6))-bit_sq.query(cor);
            ans+=(cor*cor*f+e-cor*presum)*ull(2);
            bit.update(cor,1),presum+=cor;
            bit_sq.update(cor,cor*cor);
        };

        auto del=[&](ull cor)->void{
            bit.update(cor,-1);
            bit_sq.update(cor,-cor*cor);
            presum -= cor;
        };

        for(auto it:g[u])if(it!=fa&&it!=son[u]){
            self(self, g, it, u, 0);
        }
        if(son[u])self(self, g, son[u], u, 1);
        for(auto it:g[u])if(it!=fa&&it!=son[u]){
            for(int i=dfn[it];i<dfn[it]+sz[it];++i){
                cal(a[ori[i]]);
            }
        }
        cal(a[u]);

        res[u]=ans;


        if(!sta){
            for(int i=dfn[u];i<sz[u]+dfn[u];++i)del(a[ori[i]]);
            ans=0;
        }
    };

    dfs(dfs,g,1,0);
    dsu(dsu,g,1,0,0);
    for(int i=1;i<=n;++i)ans^=res[i];
    cout<<ans;
    return 0;
}
此文章已被阅读次数:正在加载...更新于