[Baekjoon] 1197. 최소 스패닝 트리

2 minute read

문제 설명

문제

그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.

최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.

입력

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.

그래프의 정점은 1번부터 V번까지 번호가 매겨져 있고, 임의의 두 정점 사이에 경로가 있다. 최소 스패닝 트리의 가중치가 -2,147,483,648보다 크거나 같고, 2,147,483,647보다 작거나 같은 데이터만 입력으로 주어진다.

출력

첫째 줄에 최소 스패닝 트리의 가중치를 출력한다.

예제 입력 1

3 3
1 2 1
2 3 2
1 3 3

예제 출력 1

3

출처

알고리즘 분류


문제 풀이

# Graph # Tree # MST

Graph/Tree 개념을 사용하는 MST 문제입니다.


풀이 과정

어제 ‘최소 신장 트리’에 관한 내용들을 정리했고, 오늘 이를 직접 코드로 구현해보기 위해 이 문제를 풀었습니다.

특별한 점은 없고, MST 알고리즘인 Kruskal과 Prim으로 모두 구현해보았습니다.

참조: 최소 신장 트리 알아보기


전체 코드

Kruskal 알고리즘(380ms)

import sys
input = sys.stdin.readline

# kruskal (380ms)
def find(x):
    if parent[x] < 0:
        return x
    p = parent[x]
    root = find(p)
    return root
def union(x, y):
    px = find(x)
    py = find(y)
    if px == py:
        return False
    if parent[px] < parent[py]:
        parent[px] += parent[py]
        parent[py] = px
    else:
        parent[py] += parent[px]
        parent[px] = py
    return True
def kruskal(_edges):
    edges = sorted(_edges,key=lambda x:x[2])
    mst = set()
    total_weight = 0
    for u, v, w in edges:
        if union(u, v):
            mst.add((u,v))
            total_weight += w
        if len(mst) == V-1:
            break
    return total_weight

V, E = map(int, input().rstrip().split())
edges = [list(map(int,input().rstrip().split())) for _ in range(E)]
parent = [-1 for _ in range(V+1)]
print(kruskal(edges))

Prim 알고리즘(500ms)

import sys
input = sys.stdin.readline

# prim (500ms)
from heapq import heapify, heappush, heappop
def prim(graph, start):
    visited = [False for _ in range(V+1)]
    visited[start] = True
    heap = graph[start]
    heapify(heap)
    mst = set()
    total_weight = 0
    while heap:
        w, u, v = heappop(heap)
        if visited[v]:
            continue
        visited[v] = True
        mst.add((u,v))
        total_weight += w
        for nw, nu, nv in graph[v]: # v = nu
            if not visited[nv]:
                heappush(heap, (nw,nu,nv))
        if len(mst) == V-1:
            break
    return total_weight

V, E = map(int, input().rstrip().split())
graph = {i:[] for i in range(1,V+1)}
for _ in range(E):
    u, v, w = map(int, input().rstrip().split())
    graph[u].append((w,u,v))
    graph[v].append((w,v,u))
print(prim(graph, 1))


배운 점

  • 최소 신장 트리 알고리즘
    • Kruskal
    • Prim

Leave a comment