[Baekjoon] 1197. 최소 스패닝 트리
문제 설명
문제
그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.
최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.
입력
첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.
그래프의 정점은 1번부터 V번까지 번호가 매겨져 있고, 임의의 두 정점 사이에 경로가 있다. 최소 스패닝 트리의 가중치가 -2,147,483,648보다 크거나 같고, 2,147,483,647보다 작거나 같은 데이터만 입력으로 주어진다.
출력
첫째 줄에 최소 스패닝 트리의 가중치를 출력한다.
예제 입력 1
3 3
1 2 1
2 3 2
1 3 3
예제 출력 1
3
출처
- 문제의 오타를 찾은 사람: BaaaaaaaaaaarkingDog
- 데이터를 추가한 사람: djm03178, ediya
- 빠진 조건을 찾은 사람: tjrwodnjs999
알고리즘 분류
문제 풀이
# Graph # Tree # MST
Graph/Tree
개념을 사용하는 MST
문제입니다.
풀이 과정
어제 ‘최소 신장 트리’에 관한 내용들을 정리했고, 오늘 이를 직접 코드로 구현해보기 위해 이 문제를 풀었습니다.
특별한 점은 없고, MST 알고리즘인 Kruskal과 Prim으로 모두 구현해보았습니다.
참조: 최소 신장 트리 알아보기
전체 코드
Kruskal 알고리즘(380ms)
import sys
input = sys.stdin.readline
# kruskal (380ms)
def find(x):
if parent[x] < 0:
return x
p = parent[x]
root = find(p)
return root
def union(x, y):
px = find(x)
py = find(y)
if px == py:
return False
if parent[px] < parent[py]:
parent[px] += parent[py]
parent[py] = px
else:
parent[py] += parent[px]
parent[px] = py
return True
def kruskal(_edges):
edges = sorted(_edges,key=lambda x:x[2])
mst = set()
total_weight = 0
for u, v, w in edges:
if union(u, v):
mst.add((u,v))
total_weight += w
if len(mst) == V-1:
break
return total_weight
V, E = map(int, input().rstrip().split())
edges = [list(map(int,input().rstrip().split())) for _ in range(E)]
parent = [-1 for _ in range(V+1)]
print(kruskal(edges))
Prim 알고리즘(500ms)
import sys
input = sys.stdin.readline
# prim (500ms)
from heapq import heapify, heappush, heappop
def prim(graph, start):
visited = [False for _ in range(V+1)]
visited[start] = True
heap = graph[start]
heapify(heap)
mst = set()
total_weight = 0
while heap:
w, u, v = heappop(heap)
if visited[v]:
continue
visited[v] = True
mst.add((u,v))
total_weight += w
for nw, nu, nv in graph[v]: # v = nu
if not visited[nv]:
heappush(heap, (nw,nu,nv))
if len(mst) == V-1:
break
return total_weight
V, E = map(int, input().rstrip().split())
graph = {i:[] for i in range(1,V+1)}
for _ in range(E):
u, v, w = map(int, input().rstrip().split())
graph[u].append((w,u,v))
graph[v].append((w,v,u))
print(prim(graph, 1))
배운 점
- 최소 신장 트리 알고리즘
- Kruskal
- Prim
Leave a comment