DSU on tree
<pre><code class="language-cpp">#include <cstdio>
// Capacity of every array below: node ids, edge ids and color values
// (colors index cnt[]) must all stay below maxn.
const int maxn = 1000005;
// n: number of nodes (node 1 is the root).
// col[u]: color of node u; qry[u]: the color whose count is asked for at u;
// ans[u]: answer computed for node u by dfs().
int n, col[maxn], qry[maxn], ans[maxn];
// Adjacency stored as singly linked edge lists:
// m: number of edges added so far; ter[e]: target node of edge e;
// nxt[e]: next edge in the same list (0 ends it); lnk[u]: first edge out of u.
int m, ter[maxn], nxt[maxn], lnk[maxn];
// skip: child node that edt() must not descend into (0 = none);
// cnt[c]: current number of counted nodes with color c;
// sz[u]: subtree size of u; ch[u]: heavy (largest-subtree) child of u, 0 for a leaf.
int skip, cnt[maxn], sz[maxn], ch[maxn];
// Appends the directed edge u -> v to u's linked edge list.
// New edges are prepended, so lnk[u] always refers to the most recent one.
void addedge(int u, int v) {
    const int id = ++m;
    ter[id] = v;
    nxt[id] = lnk[u];
    lnk[u] = id;
}
// Post-order pass over the subtree of u.
// Fills sz[] with subtree sizes and ch[] with each node's heavy child, i.e. the
// child with the largest subtree. On ties the earliest child in lnk order wins,
// because the comparison is strict.
void gsz(int u) {
    int total = 1;
    for (int e = lnk[u]; e != 0; e = nxt[e]) {
        const int child = ter[e];
        gsz(child);
        total += sz[child];
        if (sz[ch[u]] < sz[child]) {
            ch[u] = child;
        }
    }
    sz[u] = total;
}
// Adds v (+1 or -1 in dfs) to cnt[col[x]] for every node x in u's subtree.
// It does not descend into a child equal to the global `skip`; dfs uses this to
// exclude the heavy subtree, and 0 means "exclude nothing".
void edt(int u, int v) {
    cnt[col[u]] += v;
    for (int e = lnk[u]; e; e = nxt[e]) {
        const int child = ter[e];
        if (child == skip) {
            continue;
        }
        edt(child, v);
    }
}
// DSU on tree (small-to-large).
// Sets ans[u] to the number of nodes in u's subtree whose color is qry[u].
// Order of work:
//   1. Solve the light children first; each clears its own counts.
//   2. Solve the heavy child ch[u] last and keep its counts in cnt[].
//   3. Add u and its light subtrees on top of the kept counts.
// flag == 1: leave u's whole subtree counted in cnt[] for the caller.
// flag == 0: remove it again before returning.
void dfs(int u, bool flag = 0) {
    for (int i = lnk[u]; i; i = nxt[i]) {
        if (ter[i] != ch[u]) {
            dfs(ter[i]);
        }
    }
    if (ch[u]) {
        dfs(ch[u], 1);
        // Set skip only after the heavy call returns. That call resets skip to 0
        // on exit, so assigning it beforehand made edt() re-add the heavy subtree.
        skip = ch[u];
    }
    edt(u, 1);
    ans[u] = cnt[qry[u]];
    // Clear skip so the removal below also covers the kept heavy subtree.
    skip = 0;
    if (flag == 0) {
        edt(u, -1);
    }
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &col[i]);
}
for (int i = 1; i <= n; i++) {
scanf("%d", &qry[i]);
}
for (int u, i = 2; i <= n; i++) {
scanf("%d", &u);
addedge(u, i);
}
gsz(1);
dfs(1);
for (int i = 1; i <= n; i++) {
printf("%d\n", ans[i]);
}
return 0;
}</code></pre>