[Baekjoon] 14500. 테트로미노

8 minute read

문제 설명

문제

폴리오미노란 크기가 1×1인 정사각형을 여러 개 이어서 붙인 도형이며, 다음과 같은 조건을 만족해야 한다.

  • 정사각형은 서로 겹치면 안 된다.
  • 도형은 모두 연결되어 있어야 한다.
  • 정사각형의 변끼리 연결되어 있어야 한다. 즉, 꼭짓점과 꼭짓점만 맞닿아 있으면 안 된다.

정사각형 4개를 이어 붙인 폴리오미노는 테트로미노라고 하며, 다음과 같은 5가지가 있다.

img

아름이는 크기가 N×M인 종이 위에 테트로미노 하나를 놓으려고 한다. 종이는 1×1 크기의 칸으로 나누어져 있으며, 각각의 칸에는 정수가 하나 쓰여 있다.

테트로미노 하나를 적절히 놓아서 테트로미노가 놓인 칸에 쓰여 있는 수들의 합을 최대로 하는 프로그램을 작성하시오.

테트로미노는 반드시 한 정사각형이 정확히 하나의 칸을 포함하도록 놓아야 하며, 회전이나 대칭을 시켜도 된다.

입력

첫째 줄에 종이의 세로 크기 N과 가로 크기 M이 주어진다. (4 ≤ N, M ≤ 500)

둘째 줄부터 N개의 줄에 종이에 쓰여 있는 수가 주어진다. i번째 줄의 j번째 수는 위에서부터 i번째 칸, 왼쪽에서부터 j번째 칸에 쓰여 있는 수이다. 입력으로 주어지는 수는 1,000을 넘지 않는 자연수이다.

출력

첫째 줄에 테트로미노가 놓인 칸에 쓰인 수들의 합의 최댓값을 출력한다.

예제 입력 1

5 5
1 2 3 4 5
5 4 3 2 1
2 3 4 5 6
6 5 4 3 2
1 2 1 2 1

예제 출력 1

19

예제 입력 2

4 5
1 2 3 4 5
1 2 3 4 5
1 2 3 4 5
1 2 3 4 5

예제 출력 2

20

예제 입력 3

4 10
1 2 1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1 2 1
1 2 1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1 2 1

예제 출력 3

7

출처

알고리즘 분류


문제 풀이

# Implementation # Bruteforce # Backtracking # Graph # DFS/BFS

Bruteforce, Graph, Backtracking 등 다양한 개념들이 복합적으로 녹아있는 구현 문제입니다.


풀이 과정

여러가지 풀이법을 시도해보다 어렵게 성공한 문제입니다. 풀릴 듯 풀릴 듯 잘 안풀리는 문제였는데, 이 문제를 통해 그래프 탐색 개념에 대해 다시 한 번 돌아보고 정립할 수 있었던 것 같습니다.

자세한 풀이 과정은 아래 코드를 보며 이야기하도록 하겠습니다.


전체 코드

😒 1번 풀이(브루트포스 정공법): 코드 작성 안 함

문제에서 주어지는 종이의 크기가 500x500으로 크지 않습니다. 그리고 문제의 스타일을 봤을 때, 브루트포스 식 접근이 적절하겠다는 생각이 들었습니다.

맨 처음에 시도한 브루트포스 식 풀이는 다음과 같습니다.

  • blocks_2x4와 blocks_4x2의 크기 8짜리 리스트를 원소로 갖는 리스트를 만들어서2x4짜리 직사각형에 들어갈 수 있는 수 리스트를 모두 만든다.

    '''
    blocks_2x4
    1 2 3 4
    5 6 7 8
    blocks_4x2
    1 2
    3 4
    5 6
    7 8
    '''
    
  • 테트로미노 모양에 맞게 리스트의 각 원소(크기 8짜리 리스트)에서 4개 수를 제외시키며 가장 큰 수를 구한다.

테트로미노 모양은 모두 2x4 크기의 직사각형 내에 들어오는 모양이기 때문에, 위와 같은 풀이법을 생각했습니다. 실제로 이와 같이 푼 분들도 있는데, 저는 이 풀이법이 너무 일반적이지 않고 하드 코딩스럽다고 생각해서 다른 풀이법을 생각해보았습니다.

😢 2번 풀이(그래프 자료구조, 다익스트라, 그리디): 틀렸습니다

그런데 생각을 해보니, 입력으로 주어진 2차원 배열은 우리가 그래프 문제에서 자주 사용하던 모양입니다. 따라서 그래프 탐색 알고리즘을 사용하면 되겠다는 생각을 합니다. (왜 이 생각을 빨리 못 했을까요..?😢)

그리고 어떤 그래프 탐색 알고리즘을 쓸까 생각하다가, 다익스트라를 이용하기로 합니다. 그 이유는 하나의 노드에서 주변 3개의 노드를 추가적으로 선택할 때, 숫자가 가장 높은 블록을 그리디하게 선택하면 되겠다는 생각을 했기 때문입니다. 따라서 힙을 이용해 최적 값을 먼저 탐색하는 다익스트라 알고리즘을 선택하고 구현합니다.

N, M = map(int, input().split()) # 세로, 가로
MAP = [list(map(int, input().split())) for _ in range(N)]

from heapq import heappop, heappush
didj = [(-1,0),(1,0),(0,-1),(0,1)]
def dijkstra(i,j):
    heap = [(-MAP[i][j],(i,j))] # -1 * 블록에 적힌 수(최대힙), 블록 좌표, path 길이
    visited = set() # 4개만 탐색하니 방문한 좌표를 넣는 식으로 구현(메모리 절약)
    ret = 0
    for _ in range(4):
        val, (ci,cj) = heappop(heap)
        ret += -val
        visited.add((ci,cj))
        for di, dj in didj:
            ni, nj = ci+di, cj+dj
            if 0 <= ni < N and 0 <= nj < M and (ni,nj) not in visited:
                heappush(heap, (-MAP[ni][nj],(ni,nj)))
    return ret

ans = 0
for i in range(N):
    for j in range(M):
        ans = max(ans,dijkstra(i,j))
print(ans)

하지만, 이 풀이는 틀린 풀이입니다. 애초에 그리디한 탐색으로 정답을 보장할 수 없습니다. 예를 들어 (7,7,3,6,10)과 같은 모양이 있을 때, 정답은 (7,3,10,6)인데 탐색은 (7,7,3,6)으로 이루어집니다. 가정 자체가 틀렸던 것이죠.

그런데 저는 어차피 모든 노드에 대해서 탐색하기 때문에, 이번 노드에서 최댓값을 놓쳤더라도 다른 노드에서 탐색하는 과정에서 최댓값을 찾을 수 있을거라 생각했는데, 그렇지 않나봅니다 😢 이 풀이는 기본 예제 입력은 물론 질문 검색에 있는 모든 반례들도 다 통과해서, 틀린 이유를 생각해내는 것이 쉽지 않았습니다..!!

😢 3번 풀이(backtracking): 틀렸습니다

결국 필요한 것은 dfs(backtracking)입니다. 위 풀이가 틀린 것으로 보아, 각 노드에 대해 제일 큰 수를 가지는 인접 노드들만 방문하는 것이 아니라, 가능한 모든 경우의 수를 방문해봐야 합니다. 그래프 탐색에서 dfs는 모든 경우의 수 탐색, bfs는 최단 경로 탐색에 주로 쓰이죠.

따라서 dfs를 쓰고, 이의 활용인 backtracking 알고리즘으로 답을 찾아봅시다.

N, M = map(int, input().split()) # 세로, 가로
MAP = [list(map(int, input().split())) for _ in range(N)]

didj = [(-1,0),(1,0),(0,-1),(0,1)]
def backtracking(i,j,n,max_val,val):
    if n == 4:
        return max(max_val,val)
    for di, dj in didj:
        ni, nj = i+di, j+dj
        if 0 <= ni < N and 0 <= nj < M and not visited[ni][nj]:
            visited[ni][nj] = True
            max_val = backtracking(ni,nj,n+1,max_val,val+MAP[ni][nj])
            visited[ni][nj] = False
    return max_val

ans = 0
for i in range(N):
    for j in range(M):
        ans = max(ans,dijkstra(i,j))
print(ans)

하지만 이 풀이도 틀렸습니다!!! 그 이유는 dfs는 T자 모양 탐색을 할 수 없기 때문입니다.

그러면 어떡해야 하지..? 모든 그래프 탐색법을 써봤는데 왜 안되지..? 라고 생각했습니다.

😊 4번 풀이(변형 backtracking): 시간초과

그런데, 조금만 생각해보면 T자 모양 탐색을 할 수 있는 방법이 있습니다. 바로 n=2일 때(즉 2개의 블록을 탐색 중일 때), 새로운 블록에서 다음 블록을 탐색하는 것이 아니라 다시 기존 블록에서 탐색하게 만들면 T자 모양 탐색이 가능해집니다.

N, M = map(int, input().split()) # 세로, 가로
MAP = [list(map(int, input().split())) for _ in range(N)]

didj = [(-1,0),(1,0),(0,-1),(0,1)]
def backtracking(i,j,n,max_total,total):
    if n == 4:
        return max(max_total,total)
    for di, dj in didj: 
        ni, nj = i+di, j+dj
        if 0 <= ni < N and 0 <= nj < M and not visited[ni][nj]:
            if n == 2: # T 모양 탐색
                visited[ni][nj] = True
                max_total = backtracking(i,j,n+1,max_total,total+MAP[ni][nj])
                visited[ni][nj] = False
            visited[ni][nj] = True
            max_total = backtracking(ni,nj,n+1,max_total,total+MAP[ni][nj])
            visited[ni][nj] = False
    return max_total

ans = 0
visited = [[False for _ in range(M)] for _ in range(N)]
for i in range(N):
    for j in range(M):
        visited[i][j] = True
        ans = max(ans,backtracking(i,j,1,0,MAP[i][j]))
        visited[i][j] = False
print(ans)

오!! 이제 ‘틀렸습니다’가 나오지 않습니다!! 하지만 시간 초과가 발생했어요!! ㅋㅋㅋㅋㅋ 정말 쉽게 놓아주지 않는 문제네요.

이제 우리의 숙제는 backtracking의 시간을 줄이는 것입니다.

😁 5번 풀이(변형된 backtracking + 가지치기): 맞았습니다!!(6276ms, 336ms, 292ms)

Backtracking에서 시간을 줄일 수 있는 방법은 가지치기(Pruning)입니다. 백트래킹 알고리즘의 핵심 중 하나는 유망성 검사입니다. 백트래킹 알고리즘 자체가 오래 걸리는 알고리즘이다 보니, 가능성이 없는 곳은 더 이상 탐색하지 않는 것이죠.

저는 이 핵심을 잊고 있었습니다. 위 코드에 적절한 가지치기 코드를 추가해주면 시간초과를 극복할 수 있습니다. 근데 여기서 위에 6276ms, 336ms, 292ms라고 적은 이유는 가지치기의 기준과 약간의 코드 작성 방법에 따라 시간이 크게 달라지기 때문입니다.


처음으로는 backtracking 함수의 인자인 max_total과의 비교를 했습니다. 아래 코드가 6276ms의 실행 시간을 기록합니다.

N, M = map(int, input().split()) # 세로, 가로
MAP = [list(map(int, input().split())) for _ in range(N)]
MAP_maxval = max(map(max,MAP))

def backtracking(i,j,n,max_total,total):
    if max_total >= total + MAP_maxval*(4-n): # 가지치기 (max_total)
        return max_total
    if n == 4: # 완료조건
        return max(max_total,total)
    for di, dj in didj: 
        ni, nj = i+di, j+dj
        if 0 <= ni < N and 0 <= nj < M and not visited[ni][nj]:
            if n == 2: # T 모양 탐색
                visited[ni][nj] = True
                max_total = backtracking(i,j,n+1,max_total,total+MAP[ni][nj])
                visited[ni][nj] = False
            visited[ni][nj] = True
            max_total = backtracking(ni,nj,n+1,max_total,total+MAP[ni][nj])
            visited[ni][nj] = False
    return max_total

ans = 0
visited = [[False for _ in range(M)] for _ in range(N)]
for i in range(N):
    for j in range(M):
        visited[i][j] = True
        ans = max(ans,backtracking(i,j,1,0,MAP[i][j]))
        visited[i][j] = False
print(ans)

근데 결국, 우리가 구할 정답은 ans에 있기 때문에 ans와 비교해줍시다. 그러면 336ms의 실행 시간을 보입니다.

N, M = map(int, input().split()) # 세로, 가로
MAP = [list(map(int, input().split())) for _ in range(N)]
MAP_maxval = max(map(max,MAP))

didj = [(-1,0),(1,0),(0,-1),(0,1)]
def backtracking(i,j,n,max_total,total):
    if ans >= total + MAP_maxval*(4-n): # 가지치기 (ans)
        return max_total
    if n == 4: # 완료조건
        return max(max_total,total)
    for di, dj in didj: 
        ni, nj = i+di, j+dj
        if 0 <= ni < N and 0 <= nj < M and not visited[ni][nj]:
            if n == 2: # T 모양 탐색
                visited[ni][nj] = True
                max_total = backtracking(i,j,n+1,max_total,total+MAP[ni][nj])
                visited[ni][nj] = False
            visited[ni][nj] = True
            max_total = backtracking(ni,nj,n+1,max_total,total+MAP[ni][nj])
            visited[ni][nj] = False
    return max_total

ans = 0
visited = [[False for _ in range(M)] for _ in range(N)]
for i in range(N):
    for j in range(M):
        visited[i][j] = True
        ans = max(ans,backtracking(i,j,1,0,MAP[i][j]))
        visited[i][j] = False
print(ans)

근데, 이럴 거면 굳이 계속해서 max_total 인자를 들고다니며 return해 줄 필요가 없습니다. 따라서 그냥 ans를 global로 선언하고 갱신해줍니다. 사실 저는 global 선언하는 것을 안 좋아하기는 하는데, 시간 비교를 위해 한 번 해봤습니다.

최종 코드는 아래와 같고, 실행 시간은 292ms입니다.

N, M = map(int, input().split()) # 세로, 가로
MAP = [list(map(int, input().split())) for _ in range(N)]
MAP_maxval = max(map(max,MAP))

didj = [(-1,0),(1,0),(0,-1),(0,1)]
def backtracking(i,j,n,total):
    global ans
    if ans >= total + MAP_maxval*(4-n): # 가지치기 (안할 시: 시간초과 -> max_total과 비교 시: 6276ms -> ans와 비교 시: 336ms)
        return
    if n == 4: # 완료조건
        ans = max(ans,total)
        return
    for di, dj in didj: 
        ni, nj = i+di, j+dj
        if 0 <= ni < N and 0 <= nj < M and not visited[ni][nj]:
            if n == 2: # T 모양 탐색
                visited[ni][nj] = True
                backtracking(i,j,n+1,total+MAP[ni][nj])
                visited[ni][nj] = False
            visited[ni][nj] = True
            backtracking(ni,nj,n+1,total+MAP[ni][nj])
            visited[ni][nj] = False
    return

ans = 0
visited = [[False for _ in range(M)] for _ in range(N)]
for i in range(N):
    for j in range(M):
        visited[i][j] = True
        backtracking(i,j,1,MAP[i][j])
        visited[i][j] = False
print(ans)


오래 걸린 문제지만, 그만큼 배운 것도 많았습니다.


배운 점

  • 입력의 크기가 작으면 브루트포스 알고리즘을 고려할 수 있다.

  • 입력으로 2차원 배열이 주어지면 그래프 탐색을 고려할 수 있다.

  • 그래프 탐색은 다음과 같이 나눌 수 있다.

    • 단순 탐색
      • dfs: 모든 경우 탐색, 백트래킹
      • bfs: 최단거리, 연결 요소 탐색
      • floyd-warshall: 모든 노드 간 최단 거리
    • 최적 탐색
      • dijkstra: 최단 거리 + 가중치 존재
      • bellman-ford: 최단 거리 + 음수 가중치 존재
  • 백트래킹 코드에 조건을 추가하면 T자 탐색이 가능하다.

  • 백트래킹 알고리즘의 핵심은 ‘유망성 검사(가지치기, pruning)’이다. 적절한 유망성 검사는 시간을 크게 단축시킨다.

Leave a comment