// DSU on tree (keep the heavy child's data, re-insert light subtrees) combined with Fenwick trees (BIT).
#include <bits/stdc++.h>
/*
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
*/
using namespace std;
// Numeric constants (not referenced by the solution below; kept as template boilerplate).
constexpr double eps = 1e-10;
constexpr double pi = 3.1415926535897932384626433832795;
constexpr double eln = 2.718281828459045235360287471352; // Euler's number e

// Contest shorthand macros. Every macro argument is parenthesised so that
// expression arguments such as lowbit(i + j), all(*p) or f(i, 1, c ? 5 : 3)
// expand with the intended operator precedence.
#define f(i, a, b) for (int i = (a); i <= (b); i++)
#define scan(x) scanf("%d", &(x))
#define mp make_pair
#define pb push_back
#define lowbit(x) ((x) & (-(x)))   // lowest set bit of x
#define fi first
#define se second
#define SZ(x) int((x).size())
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define summ(a) (accumulate(all(a), 0ll))

using ull = unsigned long long;
using pii = pair<int, int>;
using vi = vector<int>;
using ll = long long;
template<typename T>
// Fenwick tree (binary indexed tree) over the external indices 0..n.
//   update(x, val): adds val at index x; indices outside [0, n] are ignored.
//   query(x):       returns the sum over indices [0, x]; x > n is treated as n,
//                   and x < 0 yields T{}.
// External index i is stored at internal 1-based slot i + 1. This makes index 0
// usable: in a plain 1-based tree, update(0, ...) never terminates because the
// lowest set bit of 0 is 0.
struct BIT{
    std::vector<T> h; // internal 1-based tree, slots 1..n+1 (slot 0 unused)
    int n;            // largest external index accepted
    // Members are initialised in declaration order (h, then n).
    BIT(int sz):h(sz+2,T{}),n(sz){}
    void update(int x,T val){
        if(x<0)return; // not representable
        for(int i=x+1;i<=n+1;i+=i&-i)h[i]+=val;
    }
    T query(int x)const{
        if(x>n)x=n; // nothing is stored past n; avoids reading out of bounds
        T res{};
        for(int i=x+1;i>0;i-=i&-i)res+=h[i];
        return res;
    }
};
// Global state shared with the DFS / DSU-on-tree lambdas in main.
ull n;                                   // number of vertices
ull a[static_cast<int>(5e5) + 9];        // a[v]: value of vertex v (read for v = 1..n)
ull dfn_cnt;                             // preorder counter used to assign dfn[]
ull presum;                              // sum of the values currently inserted
ull ans;                                 // running answer for the subtree being built
ull res[static_cast<int>(5e5) + 9];      // res[v]: answer recorded for subtree v
BIT<int> bit(static_cast<int>(1e6 + 1)); // per-value count of inserted values
BIT<ull> bit_sq(static_cast<int>(1e6 + 1)); // per-value sum of squared inserted values
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cin>>n;
vector g(n+1,vector<int>());
for(int i=1,x,y;i<n;++i){
cin>>x>>y;
g[x].push_back(y),g[y].push_back(x);
}
for(int i=1;i<=n;++i)cin>>a[i];
vector<int> son(n+1,0), sz(n+1,0), dfn(n+1,0), ori(n+1,0);
auto dfs=[&](const auto& self, const vector<vector<int>>&g, int u, int fa)->void{
sz[u]=1, dfn[u]=++dfn_cnt, ori[dfn[u]]=u;
for(auto it:g[u])if(it!=fa){
self(self, g, it, u);
sz[u]+=sz[it];
son[u]=sz[son[u]]<sz[it]?it:son[u];
}
};
auto dsu=[&](const auto& self, const vector<vector<int>>&g, int u, int fa,int sta)->void{
auto cal=[&](ull cor)->void{
ull f = bit.query(cor);
ull e = bit_sq.query(int(1e6))-bit_sq.query(cor);
ans+=(cor*cor*f+e-cor*presum)*ull(2);
bit.update(cor,1),presum+=cor;
bit_sq.update(cor,cor*cor);
};
auto del=[&](ull cor)->void{
bit.update(cor,-1);
bit_sq.update(cor,-cor*cor);
presum -= cor;
};
for(auto it:g[u])if(it!=fa&&it!=son[u]){
self(self, g, it, u, 0);
}
if(son[u])self(self, g, son[u], u, 1);
for(auto it:g[u])if(it!=fa&&it!=son[u]){
for(int i=dfn[it];i<dfn[it]+sz[it];++i){
cal(a[ori[i]]);
}
}
cal(a[u]);
res[u]=ans;
if(!sta){
for(int i=dfn[u];i<sz[u]+dfn[u];++i)del(a[ori[i]]);
ans=0;
}
};
dfs(dfs,g,1,0);
dsu(dsu,g,1,0,0);
for(int i=1;i<=n;++i)ans^=res[i];
cout<<ans;
return 0;
}