본문 바로가기

Algorithms & Languages/파이썬 알고리즘 문제풀이

[Python/DailyAlgo, 백준] 86. 최소 신장 트리, 1197. 최소 스패닝 트리, 89. 동기화 서버 구축

문제 접근 및 공부 내용, 풀이는 모두 하단 코드에 "주석"으로 포함되어 있으니 참고해주세요.

문제 유형 보기

더보기
크루스칼 알고리즘

https://dailyalgo.kr/ko/problems/86

https://www.acmicpc.net/problem/1197

https://dailyalgo.kr/ko/problems/89


풀이

세 문제는 사실상 동일한 문제임.

86. 최소 신장 트리

더보기
def solution(n, edges):

    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]

    def union(x, y):
        x_root = find(x)
        y_root = find(y)

        if x_root == y_root:
            return False
    
        if rank[x_root] > rank[y_root]:
            parent[y_root] = x_root
        elif rank[y_root] > rank[x_root]:
            parent[x_root] = y_root
        else:
            parent[y_root] = x_root
            rank[x_root] += 1

        return True

    total = 0
    counts = 0
    parent = list(range(n + 1))
    rank = [0] * (n + 1)
    edges.sort(key = lambda x:x[2])
    for x, y, w in edges:
        if union(x, y):
            # 사이클 안 만들면 True, 사이클 만들면 False가 되게 해,
            # 사이클을 안 만드는 경우만 해당 간선을 추가함.
            total += w
            counts += 1
            if counts == n - 1:
                # 트리의 특성을 활용해, 모든 정점을 다 뽑았는지 확인
                # 트리의 간선 개수 = 정점 개수 - 1
                break

    return total

강사님 풀이

def solution(n, edges):
    def find(x):
        if x != parent[x]:
            parent[x] = find(parent[x])
        return parent[x]

    # 주의! range(1, n + 1)하면 그냥 n개가 나오는 거임.
    # 우리는 0번 버리고 1 ~ n 정점 번호로 쓸 거니까 range(n + 1)만 해야지.
    

    def union(x, y):
        x_root = find(x)
        y_root = find(y)

        if x_root == y_root:
            return False

        if rank[x_root] > rank[y_root]:
            parent[y_root] = x_root
        elif rank[x_root] < rank[y_root]:
            parent[x_root] = y_root
        else:
            parent[x_root] = y_root
            rank[y_root] += 1
        
        return True
    
    parent = list(range(n + 1))
    rank = [0] * (n + 1)

    # 이거 파이썬 소트는 팀 소트니까, 시간복잡도 병합 정렬처럼 O(nlog(n))임.
    # 이 때, n은 간선 개수니까, 시간복잡도는 O(Elog(E))임.
    edges.sort(key = lambda x: x[2])
    
    total = 0 # MST의 간선 가중치 합
    counts = 0 # 현재까지 뽑은 간선의 개수

    # 이거 시간복잡도는 그냥 O(E)
    for x, y, w in edges:
        if union(x, y):
            # 이 (x, y) 간선을 추가해도 사이클이 생기지 않는다면
            total += w
            counts += 1

            if counts == n - 1:
                break

    return total

1197. 최소 스패닝 트리

더보기
import sys
input = sys.stdin.readline

def find(x):
    if x != parent[x]:
        parent[x] = find(parent[x])
    return parent[x]

def union(x, y):
    x_root = find(x)
    y_root = find(y)

    if x_root == y_root:
        # 싸이클을 만들어버린 거임.
        return False

    if rank[x_root] > rank[y_root]:
        parent[y_root] = x_root
    elif rank[y_root] > rank[x_root]:
        parent[x_root] = y_root
    else:
        parent[y_root] = x_root
        rank[x_root] += 1

    return True

# 스패닝 트리 구현 문제.
# ElogE < 1억

v, e = map(int, input().split())

# 일단, edges를 가중치 순으로 오름차순 정렬해야함.
# 그러니 일단 edges 입력 받자.
edges = [list(map(int, input().split())) for _ in range(e)]
edges.sort(key = lambda x:x[2])

counts = 0
total = 0

parent = list(range(v + 1))
rank = [0] * (v + 1)

# 엣지를 앞에서부터 순서대로 하나씩 뽑아서,
# 1. 사이클을 만들지 않았으면 뽑고 아니면 무시
# 2. 간선을 정점개수 - 1개만큼 뽑았으면 빠져나오고 아니면 계속
for x, y, w in edges:
    # 싸이클을 만들면 이 간선 넣지 말란 의미로 False,
    # 싸이클을 만들지 않으면 이 간선 넣어도 된단 의미로 True를 반환할 거임.
    if union(x, y):
        counts += 1
        total += w
        if counts == v - 1:
            break

print(total)

250227 재풀이

import sys
input = sys.stdin.readline

def find(x):
    if parent[x] != x:
        parent[x] = find(parent[x])
    return parent[x]

def union(x, y):
    x_root = find(x)
    y_root = find(y)

    if x_root == y_root:
        return False

    if rank[x_root] > rank[y_root]:
        parent[y_root] = x_root
    elif rank[y_root] > rank[x_root]:
        parent[x_root] = y_root
    else:
        parent[y_root] = x_root
        rank[x_root] += 1

    return True

v, e = map(int, input().split())
edges = []
for _ in range(e):
    edges.append(tuple(map(int, input().split())))

edges.sort(key = lambda x:x[2])

parent = list(range(v + 1))
rank = [0] * (v + 1)
counts = 0
total = 0
for x, y, w in edges:
    if union(x, y):
        counts += 1
        total += w
        if counts == v - 1:
            break

print(total)

89. 동기화 서버 구축

더보기
def solution(n, syncs):

    # 전형적인 MST 문제다!
    # MST를 구할 수 없으면 -1을 반환해야 하는 것에 유의.
    # => 반복문 바깥 맨 마지막에 return -1 하면 될듯.
    # ElogE = 50만 * 5.xx = 대략 커봐야 250만 정도 -> 뭐 괜찮을듯?

    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]

    def union(x, y):
        x_root = find(x)
        y_root = find(y)

        if x_root == y_root:
            # 싸이클 만들어졌으니, MST에 넣으면 안돼!!
            return False

        if rank[x_root] > rank[y_root]:
            parent[y_root] = x_root
        elif rank[y_root] > rank[x_root]:
            parent[x_root] = y_root
        else:
            parent[y_root] = x_root
            rank[x_root] += 1

        return True


    parent = list(range(n + 1))
    rank = [0] * (n + 1)

    counts = 0
    total = 0
    syncs.sort(key = lambda x:x[2])
    for x, y, w in syncs:
        if union(x, y):
            counts += 1
            total += w

            if counts == n - 1:
                return total

    return -1

250227 재풀이

def solution(n, edges):

    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]

    def union(x, y):
        x_root = find(x)
        y_root = find(y)

        if x_root == y_root:
            # 싸이클을 만들 수 있는 상황이면,
            # 그 간선을 선택하지 말라고 False를 반환
            return False
        
        if rank[x_root] > rank[y_root]:
            parent[y_root] = x_root
        elif rank[y_root] > rank[x_root]:
            parent[x_root] = y_root  
        else:
            parent[y_root] = x_root
            rank[x_root] += 1

        return True

    parent = list(range(n + 1))
    rank = [0] * (n + 1)
    counts = 0
    total = 0
    edges.sort(key = lambda x:x[2])
    for x, y, w in edges:
        if union(x, y):
            # 사이클을 안 만들면:
            counts += 1
            total += w
            if counts == n - 1:
                return total

    return -1