백준 1043, Disjoint Set
Disjoint set, 혹은 Union-find 로 더 잘 알려진 도구를 써서 푸는 문제다.
BOJ 에서 실전으로 문제를 푼 것이 매우 오래 된 것이 생각나서 랜덤으로 문제를 골라서 풀어봤는데,
disjoint set을 쓰는 문제인 것을 알았는데도 한 번에 풀 수가 없었다.
옛날에 내 깃헙에 적어놓은 disjoint set 템플릿도 찾지를 못했고,
바닥부터 구현을 하려고 했을때도 기억이 하나도 안 났다.
신년에 세운 목표 지키기가 매우 어렵다.
진실을 알고 있는 사람들과 같이 파티를 한 모든 사람들은 진실을 알게 되므로,
진실을 알고 있는 사람들과 disjoint 한 사람들만 참가 했는지 확인하면 된다.
한 가지 생각 해야 하는 점은, 진실을 아는 사람들은 파티에 참여하는 인원에 따라 바뀐다는 점이다.
다른 말로, 파티에 참여하는 사람들이 적힌 $M$개의 줄 순서가 달라지면 참여 가능 파티 개수도 달라진다.
이를 놓쳐서 계속 틀렸다. 날카로웠던 적이 없었지만 더욱 많이 무뎌졌음이 실감 난다.
Source Code
#define _CRT_SECURE_NO_WARNINGS
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int, int>
#define pll pair<ll, ll>
#define pdd pair<double, double>
#define MAXI 1000000000
#define cplx complex<double>
int N, M, P;
struct disjointset {
int parent[55];
int rank[55];
long long setsize[55];
disjointset() {
for (int i = 0; i < 55; ++i) parent[i] = i;
memset(setsize, 1, sizeof(setsize));
memset(rank, 0, sizeof(rank));
}
int find(int u) {
if (parent[u] == u) return u;
return parent[u] = find(parent[u]);
}
bool merge(int u, int v) {
u = find(u);
v = find(v);
if (u == v) return false;
if (rank[u] > rank[v]) swap(u, v);
setsize[v] += setsize[u];
setsize[u] = 1;
parent[u] = v;
if (rank[u] == rank[v]) ++rank[v];
return true;
}
int size(int u) { return setsize[find(u)]; }
};
vector<int> plist;
vector<vector<int>> party;
disjointset djs;
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
//ifstream cin; cin.open("input.txt");
cin >> N >> M;
cin >> P;
for (int i = 0; i < P; ++i) {
int tmp;
cin >> tmp;
plist.push_back(tmp);
}
for(int i=0;i<M;++i){
party.push_back(vector<int>());
int tmp,tmp2;
bool chk = false;
cin >> tmp;
for(int j=0;j<tmp;++j){
cin >> tmp2;
party[i].push_back(tmp2);
}
for(int j=1;j<tmp;++j){
djs.merge(party[i][0], party[i][j]);
}
}
int cnt=0;
for(int i=0;i<M;++i){
bool chk = false;
for (int j = 0; j < party[i].size(); ++j) {
for(int k=0;k<plist.size();++k){
if(djs.find(party[i][j]) == djs.find(plist[k])){
chk = true;
break;
}
}
}
if(!chk) ++cnt;
}
printf("%d\n", cnt);
return 0;
}