DSU on tree

#include <cstdio>
// Capacity shared by every array below. Index 0 is never a real vertex or
// edge: a 0 in lnk/nxt ends an adjacency list, and a 0 in ch/skip means "none".
const int maxn = 1000005;

// --- Problem data, indexed by vertex 1..n ---
int n;           // number of vertices
int col[maxn];   // col[v]: value of vertex v; edt() tallies it in cnt[]
int qry[maxn];   // qry[v]: value whose tally is looked up for vertex v
int ans[maxn];   // ans[v]: cnt[qry[v]] sampled in dfs(), printed by main()

// --- Child lists stored as linked edge records ---
int m;           // number of edge records in use
int ter[maxn];   // ter[e]: vertex that edge e points to
int nxt[maxn];   // nxt[e]: next edge in the same list, 0 at the end
int lnk[maxn];   // lnk[u]: most recently added edge leaving u, 0 if none

// --- Traversal state ---
int skip;        // child that edt() must not descend into, 0 for none
int cnt[maxn];   // cnt[c]: net number of vertices with col == c added by edt()
int sz[maxn];    // sz[u]: number of vertices in u's subtree (set by gsz)
int ch[maxn];    // ch[u]: child of u with the largest subtree, 0 if u has none
// Records a directed edge u -> v by pushing a new edge record onto the
// front of u's list. Edge ids start at 1 so that 0 can mark the list end.
void addedge(int u, int v) {
    const int e = m + 1;
    ter[e] = v;
    nxt[e] = lnk[u];  // the previous head becomes the successor
    lnk[u] = e;
    m = e;
}
// Fills sz[] and ch[] for u and everything below it.
// sz[u] becomes the number of vertices in u's subtree; ch[u] becomes the
// child with the strictly largest subtree seen so far (it stays 0 when u
// has no children, which works because sz[0] is 0).
void gsz(int u) {
    sz[u] = 1;
    int e = lnk[u];
    while (e != 0) {
        const int child = ter[e];
        gsz(child);
        sz[u] += sz[child];
        if (sz[ch[u]] < sz[child]) {
            ch[u] = child;
        }
        e = nxt[e];
    }
}
// Adds v to cnt[col[x]] for u and for every vertex below u, except that
// the walk never enters the vertex currently stored in the global `skip`.
// Called with v = 1 to add a subtree and v = -1 to remove it again.
void edt(int u, int v) {
    cnt[col[u]] += v;
    for (int e = lnk[u]; e != 0; e = nxt[e]) {
        const int child = ter[e];
        if (child == skip) {
            continue;
        }
        edt(child, v);
    }
}
// Computes ans[x] = cnt[qry[x]] for u and every vertex below it, where cnt
// holds the col[] tallies of exactly x's subtree at the moment of sampling.
//
// u    - root of the subtree to solve; gsz() must already have filled ch[].
// flag - when true, u's subtree is left counted in cnt[] on return so the
//        caller can reuse it; when false (the default), cnt[] is restored
//        to the state it had on entry.
//
// Expects `skip` to be 0 on entry and leaves it 0 on return.
void dfs(int u, bool flag = 0) {
    // Solve every child except ch[u] first; each of these calls removes its
    // own contribution from cnt[] before returning.
    for (int i = lnk[u]; i; i = nxt[i]) {
        if (ter[i] != ch[u]) {
            dfs(ter[i]);
        }
    }
    if (ch[u]) {
        // Solve ch[u] last and keep its subtree in cnt[].
        dfs(ch[u], 1);
        // Mark it as already counted only AFTER the recursive call: every
        // dfs() resets skip to 0 before returning, so setting it earlier
        // would be lost and the kept subtree would be counted twice below.
        skip = ch[u];
    }
    // Add u itself plus all subtrees that are not already counted.
    edt(u, 1);
    ans[u] = cnt[qry[u]];
    // Clear the marker so the removal below (and any caller) sees the
    // whole subtree.
    skip = 0;
    if (!flag) {
        edt(u, -1);
    }
}
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &col[i]);
    }
    for (int i = 1; i <= n; i++) {
        scanf("%d", &qry[i]);
    }
    for (int u, i = 2; i <= n; i++) {
        scanf("%d", &u);
        addedge(u, i);
    }
    gsz(1);
    dfs(1);
    for (int i = 1; i <= n; i++) {
        printf("%d\n", ans[i]);
    }
    return 0;
}