본문 바로가기

Algorithms & Languages/파이썬 알고리즘 문제풀이

[Python/백준] 1717. 집합의 표현

문제 접근 및 공부 내용, 풀이는 모두 하단 코드에 "주석"으로 포함되어 있으니 참고해주세요.

문제 유형 보기

더보기
더보기
유니온 파인드

https://www.acmicpc.net/problem/1717


풀이

import sys
input = sys.stdin.readline
# 백준 때문에
sys.setrecursionlimit(10 ** 6)

n, m = map(int, input().split())
parent = list(range(n+1))
rank = [0] * (n + 1)

def find(x):
    if parent[x] != x:
        # 만약 루트 노드가 아니면
        # -> 루트 노드를 찾아가서, 그 루트 노드의 parent 즉 루트 노드 그 자체를 return한 걸
        # 계속 끌고 오며 갱신하며 가져옴.
        parent[x] = find(parent[x])
    return parent[x]

def union(x, y):
    x_root = find(x)
    y_root = find(y)

    if x_root == y_root:
        # 이미 같은 집합에 속해 있으면, Union하지 않음.
        return

    # if x_root < y_root:
    #     parent[y_root] = x_root
    #     return
    # parent[x_root] = y_root
    # 이거 삼항으로도 구현 가능

    # 위 부분이 문제임.
    if rank[x_root] > rank[y_root]:
        # y_root가 x_root에 붙어야 함.
        parent[y_root] = x_root

    elif rank[x_root] < rank[y_root]:
        parent[x_root] = y_root
    else:
        # 아무거나 나머지에 붙이면 됨.
        parent[y_root] = x_root
        rank[x_root] += 1

    # 이제, 만약 경로 압축이 없고 랭크만 있으면, 시간복잡도는 O(log(n))임.
    # 트리 높이만큼 들어가니까.
    # 근데, rank도 있고 경로압축도 있으면, 거의 O(1)

for _ in range(m):
    command, a, b = map(int, input().split())

    # Union
    if command == 0:
        union(a, b)
        continue

    # Find
    print("YES" if find(a) == find(b) else "NO")

https://cuffyluv.tistory.com/181

위 정리본도 참고.


250227 재풀이

import sys
input = sys.stdin.readline
sys.setrecursionlimit(10 ** 6)

# n + 1개의 집합 0 ~ ㅜ
# 합집합이랑 find 할 건데
# 입력 : n, m(연산개수)
# 0 a b -> a와 b Union하겠다
# 1 a b -> a와 b 같은 집합인지 확인하겠다.
# a b 같은 집합이면 YES 아니면 NO 한 줄씩 출력

def find(x):
    if parent[x] != x:
        parent[x] = find(parent[x])
    return parent[x]

def union(x, y):
    x_root = find(x)
    y_root = find(y)

    if x_root == y_root:
        return

    if rank[x_root] > rank[y_root]:
        parent[y_root] = x_root
    elif rank[y_root] > rank[x_root]:
        parent[x_root] = y_root
    else:
        parent[y_root] = x_root
        rank[x_root] += 1

n, m = map(int, input().split())
parent = list(range(n + 1))
rank = [0] * (n + 1)

for _ in range(m):
    operation, a, b = map(int, input().split())

    if operation == 0:
        union(a, b)
        continue

    # if operation == 1:
    if find(a) == find(b):
        print('YES')
    else:
        print('NO')

    # print('YES' if find(a) == find(b) else 'NO')