Union Find
2가지 연산이 있다.
- group(A): 정점 A가 속한 그룹의 대표
int r[1000];
int group(int node)
{
if(!r[node]) return node;
return r[node] = group(r[node]);
}
- join(A, B): 정점 A가 속한 그룹과 정점의 B가 속한 그룹을 합친다.
void join(int A, int B)
{
int ga = group(A);
int gb = group(B);
if(ga==gb) return; // do nothing
r[ga]=gb;
}
Kruskal
최소 신장 트리 (Minimum Spanning Tree)를 구하는 알고리즘. 간선들을 가중치 오름차순으로 정렬한 후, 첫번째 간선부터 보면서 현재 간선이 사이클을 만들지 않으면 신장 트리에 포함시킨다.
sort(edges);
for(each edge)
if(!connected(edge.s, edge.e))
connect(edge.s, edge.e), ans += edge.w;
- 문제번호
1. 1717
#include <stdio.h>
int r[1000001];
int size[1000001];
int find_father(int p) {
if (r[p] == p) return p;
return r[p] = find_father(r[p]);
}
void Union(int p1, int p2) {
find_father(p1);
find_father(p2);
if (size[r[p1]]>size[r[p2]]) r[r[p2]] = r[p1];
else r[r[p1]] = r[p2];
size[r[p1]] = size[p1] + size[p2];
}
int check(int p1, int p2) {
return find_father(p1) == find_father(p2) ? 1 : 0;
}
int main() {
int n, m, type, set1, set2;
scanf("%d %d", &n, &m);
for (int i = 0; i <= n; i++)
{
r[i] = i;
size[i] = 1;
}
for (int i = 0; i < m; i++)
{
scanf("%d %d %d", &type, &set1, &set2);
if (type == 0) Union(set1, set2);
else check(set1, set2)==1?printf("YES\n"):printf("NO\n");
}
}
2. 1197
#define _CRT_SECURE_NO_WARNINGS
#include <stdio.h>
typedef enum { false, true } bool;
int cnt = 0;
struct Node {
int p1;
int p2;
int value;
};
struct Node line[100001];
int r[10001];
int size[10001];
int chk = 0;
void qsort(int start, int end) {
int i = start;
int j = end;
int tmp;
int mid = (i + j) / 2;
int pivot = line[mid].value;
while (i <= j) {
while (line[i].value < pivot) i++;
while (line[j].value > pivot) j--;
if (i <= j)
{
tmp = line[i].value;
line[i].value = line[j].value;
line[j].value = tmp;
tmp = line[i].p1;
line[i].p1 = line[j].p1;
line[j].p1 = tmp;
tmp = line[i].p2;
line[i].p2 = line[j].p2;
line[j].p2 = tmp;
i++;
j--;
}
}
if (start < j ) qsort(start, j);
if (i < end) qsort(i, end);
}
int find_father(int p) {
if (r[p] == p) return p;
return r[p] = find_father(r[p]);
}
void join(int p1, int p2) {
find_father(p1);
find_father(p2);
if (size[r[p1]]>size[r[p2]]) r[r[p2]] = r[p1];
else r[r[p1]] = r[p2];
size[r[p1]] = size[r[p1]] + size[r[p2]];
}
bool conn(int p1, int p2) {
find_father(p1);
find_father(p2);
return r[p1] == r[p2] ? true : false;
}
int main() {
int V, E;
int Answer = 0;
scanf("%d %d", &V, &E);
for (int i = 1; i <= E; i++)
{
scanf("%d %d %d", &line[i].p1, &line[i].p2, &line[i].value);
}
for (int i = 1; i <= V; i++)
{
r[i] = i;
size[i] = 1;
}
qsort(1, E);
for (int i = 1; i <= E; i++)
{
if (!conn(line[i].p1, line[i].p2)) {
Answer += line[i].value;
join(line[i].p1, line[i].p2);
cnt++;
if (cnt == V - 1) break;
}
}
printf("%d\n", Answer);
}
3. 1922
#define _CRT_SECURE_NO_WARNINGS
#include <stdio.h>
typedef enum { false, true } bool;
int cnt = 0;
struct Node {
int p1;
int p2;
int value;
};
struct Node line[100001];
int r[10001];
int size[10001];
int chk = 0;
void qsort(int start, int end) {
int i = start;
int j = end;
int tmp;
int mid = (i + j) / 2;
int pivot = line[mid].value;
while (i <= j) {
while (line[i].value < pivot) i++;
while (line[j].value > pivot) j--;
if (i <= j)
{
tmp = line[i].value;
line[i].value = line[j].value;
line[j].value = tmp;
tmp = line[i].p1;
line[i].p1 = line[j].p1;
line[j].p1 = tmp;
tmp = line[i].p2;
line[i].p2 = line[j].p2;
line[j].p2 = tmp;
i++;
j--;
}
}
if (start < j ) qsort(start, j);
if (i < end) qsort(i, end);
}
int find_father(int p) {
if (r[p] == p) return p;
return r[p] = find_father(r[p]);
}
void join(int p1, int p2) {
find_father(p1);
find_father(p2);
if (size[r[p1]]>size[r[p2]]) r[r[p2]] = r[p1];
else r[r[p1]] = r[p2];
size[r[p1]] = size[r[p1]] + size[r[p2]];
}
bool conn(int p1, int p2) {
find_father(p1);
find_father(p2);
return r[p1] == r[p2] ? true : false;
}
int main() {
int V, E;
int Answer = 0;
scanf("%d %d", &V, &E);
for (int i = 1; i <= E; i++)
{
scanf("%d %d %d", &line[i].p1, &line[i].p2, &line[i].value);
}
for (int i = 1; i <= V; i++)
{
r[i] = i;
size[i] = 1;
}
qsort(1, E);
for (int i = 1; i <= E; i++)
{
if (!conn(line[i].p1, line[i].p2)) {
Answer += line[i].value;
join(line[i].p1, line[i].p2);
cnt++;
if (cnt == V - 1) break;
}
}
printf("%d\n", Answer);
}
4. 2887
아이디어는 x, y, z좌표를 각각 순서대로 정렬하면(대신 index는 보존)
각각의 x,y,z값은 다음 값과의 차가 최소 거리이다.는게 포인트!!
따라서 최소 엣지는 정렬한 x중 x1-x2, x2-x3 .... 총 O(N-1)
마찬가지로 y, z도 O(N-1)
O(3*(NlogN + N-1))이므로 O(NlogN+N)에 끝남
#define _CRT_SECURE_NO_WARNINGS
#include <stdio.h>
#define MAX 2000000000
struct pt {
int i;
int v;
};
struct Edge {
int start;
int end;
int v;
};
int N;
struct pt point[300003];
struct Edge edge[300003];
void q_sort(struct pt *arr, int start, int end) {
int i = start;
int j = end;
struct pt tmp;
int pivot = arr[(start + end) / 2].v;
while (i <= j)
{
while (arr[i].v < pivot) i++;
while (arr[j].v > pivot) j--;
if (i <= j) {
tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
i++;
j--;
}
}
if (start < j) q_sort(arr, start, j);
if (i < end) q_sort(arr, i, end);
}
void q_sort_e(struct Edge *arr, int start, int end) {
int i = start;
int j = end;
struct Edge tmp;
int pivot = arr[(start + end) / 2].v;
while (i <= j)
{
while (arr[i].v < pivot) i++;
while (arr[j].v > pivot) j--;
if (i <= j) {
tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
i++;
j--;
}
}
if (start < j) q_sort_e(arr, start, j);
if (i < end) q_sort_e(arr, i, end);
}
void mk_node(int start, int end) {
edge[start].start = point[start].i;
edge[start].end = point[end].i;
edge[start].v = point[start].v < point[end].v ? point[end].v - point[start].v : point[start].v - point[end].v;
}
int r[100001];
int size[100001];
int update(int a) {
if (r[a] == a) return a;
r[a] = update(r[a]);
return r[a];
}
int conn(int a, int b) {
update(a);
update(b);
return r[a] == r[b] ? 1 : 0;
}
void join(int a, int b) {
if (size[a]>size[b])
{
r[r[b]] = r[a];
size[a] += size[r[b]];
}
else
{
r[r[a]] = r[b];
size[b] += size[r[a]];
}
}
int main() {
scanf("%d", &N);
for (int i = 1; i <= N; i++)
{
point[i].i = point[i + N].i = point[i + N * 2].i = i;
scanf("%d %d %d", &point[i].v, &point[i + N].v, &point[i + 2 * N].v);
}
q_sort(point, 1, N);
q_sort(point, N + 1, 2 * N);
q_sort(point, 2 * N + 1, 3 * N);
char audit;
//for (int i = 1; i <= 3 * N; i++)
//{
// if (i <= N) audit = 'x';
// else if (i <= 2 * N) audit = 'y';
// else audit = 'z';
// printf("%c %d %d\n", audit, point[i].i, point[i].v);
//}
for (int i = 1; i < N; i++)
{
mk_node(i, i + 1); //edge[1] ~ edge[N-1]
mk_node(i + N, i + 1 + N); //edge[N+1] ~ edge[2N-1]
mk_node(i + 2 * N, i + 1 + 2 * N); //edge[2N+1] ~ edge[3N-1]
}
edge[N].v = MAX;
edge[2 * N].v = MAX;
edge[3 * N].v = MAX;
q_sort_e(edge, 1, 3 * N - 1);
//for (int i = 1; i <= 3 * N - 1; i++)
//{
// printf("%d %d %d\n", edge[i].start, edge[i].end, edge[i].v);
//}
//Kruskal Algorithm start
for (int i = 1; i <= N; i++)
{
r[i] = i;
size[i] = 1;
}
int Answer = 0;;
int cnt = 0;
for (int i = 1; i <= 3*N-3; i++)
{
if (!conn(edge[i].start, edge[i].end)) {
join(edge[i].start, edge[i].end);
Answer += edge[i].v;
cnt++;
}
if (cnt == N - 1) break;
}
printf("%d", Answer);
}
5. 1626
Kruskal을 통해 MST를 구하고, 각 노드들을 하나씩 빼면서 새로 MST를 구하려했는데 두가지 문제점이 생겼다. TLE가 계속뜬다. 그래서 생각해본 문제점은
1. 노드들을 빼는 작업이 너모 오래걸린다.
- 노드를 제거하면 두개의 트리가 생기고, 이때 각 트리의 root를 새로 지정해 줘야 나중에 사이클 체크를 할 수 있는데, root를 새로 지정하는게 넘오래걸림..
2. 결국 시간복잡도가 문제인데, 효율적으로 제거하는 방법을 모른채, 초기값으로 새로 세팅해서 탐색하면 세팅하는데 O(V) , 탐색하는데 O((V-1)*E)) ≤ O(V^3) ..덜ㄷ럳러
Kruskal하는데 걸리는 시간이 O(ElogE + E)였는데 이거보다 새로탐색하는데 더 오래걸리니.. TLE가 뜨는게 아닌가 싶은데? 저렇게 무식하게 새로 셋팅하는거 말고.. 기존 MST를 가지고 하는 방법 뭐 없나..
6. 7627
7. 3697