// 题目链接 (problem link) — the URL itself was not preserved in this copy.

#include <bits/stdc++.h>
/*
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
*/
using namespace std;

// Template constants: comparison tolerance, pi and Euler's number e.
constexpr double eps = 1e-10;
constexpr double pi = 3.1415926535897932384626433832795;
constexpr double eln = 2.718281828459045235360287471352;

// Competitive-programming shorthand macros.
// Every macro parameter is parenthesized in the expansion so that expression
// arguments keep their intended precedence: lowbit(a + b) must not become
// `a + b & -a + b`, and all(*p) must not become `*p.begin()`.
// Note: f() re-evaluates its upper bound `b` on every iteration.
#define f(i, a, b) for (int i = (a); i <= (b); i++)
#define scan(x) scanf("%d", &(x))
#define mp std::make_pair
#define pb push_back
#define lowbit(x) ((x) & -(x)) // lowest set bit of x

#define fi first
#define se second
#define SZ(x) int((x).size())
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define summ(a) (std::accumulate(all(a), 0ll)) // sum of a container as long long

using ull = unsigned long long;
using pii = std::pair<int, int>;
using vi = std::vector<int>;

using ll = long long;

// Capacity of every per-node / per-colour array. Node ids and colour values
// are used directly as indices, so both are assumed to lie in [0, kMaxN) —
// NOTE(review): confirm against the problem's limits.
constexpr int kMaxN = 100000 + 9;

ll n;               // number of nodes (read first in main)
ll m;               // number of queries
ll ans[kMaxN];      // ans[u]: distinct colours in the subtree of u (filled by dfs2)
ll c[kMaxN];        // c[u]: value read for node u (presumably its colour)
ll cnt[kMaxN];      // cnt[col]: occurrences of colour col in the current working set
ll sz[kMaxN];       // sz[u]: size of the subtree rooted at u (filled by dfs)
ll son[kMaxN];      // son[u]: heavy child of u, 0 when u has no children
ll res;             // number of colours whose cnt is currently non-zero
std::vector<int> g[kMaxN]; // g[u]: neighbours of u (undirected adjacency lists)

// First DSU-on-tree pass over the subtree of u (fa = u's parent, 0 for the root).
// Fills sz[u] with the subtree size and son[u] with the child whose subtree is
// largest; on ties the child met first in g[u] is kept, and son[u] stays 0 for
// a leaf.
void dfs(int u, int fa)
{
    for (const int v : g[u])
    {
        if (v == fa)
            continue; // do not walk back up to the parent

        dfs(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) // strictly larger only: earlier child wins ties
            son[u] = v;
    }
    ++sz[u]; // count u itself
}

// Visits every node of the subtree of u (fa = u's parent) and adds `val` to the
// counter of that node's colour c[node]. Any neighbour equal to `bson` is
// skipped together with its whole subtree; pass bson = 0 to visit everything.
// With val = +1, res grows by one whenever a colour's count reaches 1; with any
// other val, res shrinks by one whenever a colour's count drops to 0. So for
// val = +1 / -1, res keeps tracking the number of colours with a non-zero count.
void cal(int u, int fa, int val, int bson)
{
    auto &bucket = cnt[c[u]];
    bucket += val;
    if (val == 1)
    {
        if (bucket == 1) // colour just appeared
            ++res;
    }
    else if (bucket == 0) // colour just disappeared
    {
        --res;
    }

    for (const int v : g[u])
    {
        if (v == fa || v == bson)
            continue;
        cal(v, u, val, bson);
    }
}

// Second DSU-on-tree pass (small-to-large merging). Computes ans[] for every
// node in the subtree of u (fa = u's parent):
//   1. each light child is solved recursively and its counts are discarded;
//   2. the heavy child son[u] is solved last and its counts are kept;
//   3. u and all light subtrees are added on top, and res is recorded in ans[u].
// When opt is non-zero (u is a light child of its parent), the entire subtree
// of u is removed from cnt/res before returning.
void dfs2(int u, int fa, int opt)
{
    const auto heavy = son[u];

    for (const int v : g[u])
        if (v != fa && v != heavy)
            dfs2(v, u, 1); // light child: solve, then clear

    if (heavy != 0)
        dfs2(heavy, u, 0); // heavy child: keep its colour counts

    cal(u, fa, 1, heavy); // add u plus every light subtree
    ans[u] = res;

    if (opt != 0)
        cal(u, fa, -1, 0); // wipe the whole subtree, heavy child included
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n;
    for (int i = 1, x, y; i < n; ++i)
    {
        cin >> x >> y;
        g[x].push_back(y), g[y].push_back(x);
    }
    f(i, 1, n) cin >> c[i];
    dfs(1, 0);
    dfs2(1, 0, 0);
    cin >> m;
    for (int i = 1, x; i <= m; ++i)
    {
        cin >> x;
        cout << ans[x] << "\n";
    }
    return 0;
}
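// Page residue copied from the source blog (not part of the program): "此文章已被阅读次数:正在加载...更新于" ("read count: loading... updated").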