[Tree] ‘최소 신장 트리’에 대한 고찰

2 minute read

‘최소 신장 트리’에 대해 새로운 부분을 발견할 때마다 업데이트합니다.

‘최소 신장 트리’에 대한 고찰

신장 트리

신장 트리그래프 내의 모든 정점을 포함하는 트리이다.

  • 트리의 일종이므로 사이클이 존재하지 않는다.
  • 따라서, 그래프의 정점이 N개 일 때 신장 트리의 간선은 N-1개 이다.

최소 신장 트리

최소 신장 트리는 (무방향) 그래프에 가중치가 있을 때, 간선들의 가중치 합이 최소인 신장트리이다.

  • 그래프 탐색에서는 일반적으로 정보를 graph[u].append((v,w))와 같이 받는다.
  • 최소 신장 트리 문제에서는 Prim 알고리즘Kruskal 알고리즘이 정보를 받는 방식이 다르다.
    • Prim: graph[u].append((w,u,v)) (정점 기준)
    • Kruskal: edges.append((w,u,v)) (간선 기준)


Prim 알고리즘

시작 정점에서 출발하여 정점을 하나씩 선택하며 신장트리 집합을 확장해나가는 방법

  1. 임의의 시작 정점을 하나 정한다. 시작 정점만 포함된 신장 트리 집합을 만든다.
  2. N개의 정점이 모두 선택될 때까지 아래 과정을 반복한다.
    • 신장 트리 집합에 포함된 정점에 인접하면서 아직 방문하지 않은 정점 중, 최소 비용의 간선이 존재하는 정점을 선택한다.
    • 선택한 정점은 신장 트리 집합에 포함시킨다.
  • Prim 알고리즘은 탐색을 진행하면서 가중치가 제일 낮은 간선을 추가하기 때문제 Heap 자료구조를 사용한다.
# ...
# 무방향 그래프 생성
for i in range(M):
    u, v, weight = map(int,input().split())
    graph[u].append([weight, u, v])
    graph[v].append([weight, v, u])

# prim algorithm
def prim(graph, start_node):
    visited = [False for _ in range(N+1)]
    visited[start_node] = 1
    heap = graph[start_node]
    heapq.heapify(candidate)
    mst = set() # mst
    total_weight = 0 # 전체 가중치

    while heap:
        w, u, v = heapq.heappop(heap) # 가중치가 가장 적은 간선 추출
        if not visited[v]: # 방문하지 않았다면
            visited[v] = True # 방문 갱신
            mst.add((u,v)) # mst 삽입
            total_weight += weight # 전체 가중치 갱신

            for edge in graph[v]: # 다음 인접 간선 탐색
                if not visited[edge[2]]: # 방문한 노드가 아니라면, (순환 방지)
                    heapq.heappush(heap, edge) # 우선순위 큐에 edge 삽입
        if len(mst) == N-1: # MST에 N-1개의 간선이 추가되었으면 종료
            break

    return total_weight
# ...


Kruskal 알고리즘

가중치에 따라 정렬된 간선을 하나씩 선택하며 신장 트리 집합을 확장해 나가는 방법

  1. 그래프의 간선들을 가중치를 기준으로 오름차순 정렬한다.
  2. N-1 개의 간선이 선택될 때까지 아래 과정을 반복한다.
    • 가중치가 낮은 간선부터 탐색하면서, 사이클을 형성하지 않으면 간선을 선택한다.
    • 사이클 형성 여부는 union-find 알고리즘을 사용하여 판단한다.
  • Kruskal 알고리즘은 간선 리스트를 가중치를 기준으로 정렬하여, 사이클을 형성하지 않는 간선을 차례대로 선택한다.
# ...
# 간선 리스트 생성
parent = [-1 for _ in range(N+1)] # 유니온파인드
edges = []
for _ in range(M):
    a,b,c = map(int,input().split())
    edges.append((c,a,b))
    
# kruskal algorithm
def kruskal(edges):
    mst = set()
    total_weight = 0
    edges.sort() 
    for edge in edges:
        w,x,y = edge
        if union(x,y): 
            mst.add((x,y))
            total_weight += w
        if len(mst) == N-1: 
            break
    return total_weight

# 유니온 파인드
def find(x):
    if parent[x] < 0:
        return x
    p = find(parent[x])
    parent[x] = p
    return p

def union(x,y):
    x = find(x)
    y = find(y)
    if x == y: 
        return False
    if parent[x] < parent[y]:
        parent[x] += parent[y]
        parent[y] = x
    else:
        parent[y] += parent[x]
        parent[x] = y
    return True
# ...


Prim vs Kruskal

  • 시간 복잡도(V: 정점의 수, E: 간선의 수)
    • Prim: O(N^2)
    • Kruskal: O(ElogE)
  • 사용 경우
    • Prim: N^2 < ElogE
    • Kruskal: N^2 > ElogE
    • 일반적으로 ElogE의 값이 더 작은 경우가 대부분이어서 Kruskal이 많이 쓰인다.


Leave a comment