纪念一道缩点的题目

3个点互相到就可以缩,有向图

#include <bits/stdc++.h>
using namespace std;
typedef pair<int, int> P;
const int maxn = 305;
bitset<maxn> in[maxn], out[maxn], del, temp;
priority_queue<P, vector<P>, greater<P> > que;
int fa[maxn];
int n, m;
int dis[maxn];
int findfa(int t){
    if(t == fa[t])    return fa[t];
    else    return fa[t] = findfa(fa[t]);
}
//被缩掉的点需要修改为他的fa 
void correct(bitset<maxn> &t){
    for(int i = 1; i <= n; i++){
        if(del[i] == 1){
            int f = findfa(i);
            if(t[i] == 1){
                t[i] = 0;
                t[f] = 1;
            }
        }    
    }
} 
//这个用来修改点in和out里面的值
//被缩掉的点需要修改为他的fa 
void clear(int t){
    correct(in[t]);
    correct(out[t]);
    in[t].reset(t);
    out[t].reset(t);
}
//这个是用来缩点的 
void merge(int x, int y){
    in[x] |= in[y];
    out[x] |= out[y];
    del.set(y);
    fa[findfa(y)] = findfa(x);
}
//类似于网络流一直找増广路的思路
//虽然看不到复杂度上限,但是上限其实极低 
//用bitset来存图,点必须在300个以内 
//思路就是每个点有一个in和out,外加全局的del(被缩掉的点集合) 
bool shrink(){
    bool res = false;
    for(int i = 1; i <= n; i++){
        if(del[i])    continue;
        for(int j = 1; j <= n; j++){
            if(del[j])    continue;
            if(!out[i][j])    continue;
            if(in[i][j]){
                merge(i, j);
                clear(i);
                res = true;
            }else{
                correct(in[i]);
                correct(out[j]);
                temp = out[j] & in[i];
                if(temp.any()){
                    merge(i, j);
                    for(int h = 1; h <= n; h++){
                        if(temp[h])    merge(i, h);
                    }
                    clear(i);
                    res = true;
                }
            }
        }
    }
    return res;
}
int main(){
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= m; i++){
        int u, v;
        scanf("%d%d", &u, &v);
        out[u].set(v);
        in[v].set(u);
    }
    for(int i = 1; i <= n; i++){
        fa[i] = i;
        dis[i] = 1e9+7;
    }
    bool res = true;
    while(res){
        res = shrink();
    }
    int fa1 = findfa(1);
    dis[fa1] = 0;
    que.push(P(0, fa1));
    while(que.size()){
        P p = que.top();
        que.pop();
        if(dis[p.second] < p.first)    continue;
        for(int i = 1; i <= n; i++){
            if(del[i])    continue;
            if(out[p.second][i] == 0)    continue;
            if(dis[i] > p.first + 1){
                dis[i] = p.first + 1;
                que.push(P(dis[i], i));
            }
        }
    }
    for(int i = 1; i <= n; i++){
        int f = findfa(i);
        if(dis[f] >= 1e9)    printf("-1 ");
        else    printf("%d ", dis[f]);
    }
    return 0;
}