[Baekjoon] 17626. Four Squares

2 minute read

문제 설명

문제

라그랑주는 1770년에 모든 자연수는 넷 혹은 그 이하의 제곱수의 합으로 표현할 수 있다고 증명하였다. 어떤 자연수는 복수의 방법으로 표현된다. 예를 들면, 26은 52과 12의 합이다; 또한 42 + 32 + 12으로 표현할 수도 있다. 역사적으로 암산의 명수들에게 공통적으로 주어지는 문제가 바로 자연수를 넷 혹은 그 이하의 제곱수 합으로 나타내라는 것이었다. 1900년대 초반에 한 암산가가 15663 = 1252 + 62 + 12 + 12라는 해를 구하는데 8초가 걸렸다는 보고가 있다. 좀 더 어려운 문제에 대해서는 56초가 걸렸다: 11339 = 1052 + 152 + 82 + 52.

자연수 n이 주어질 때, n을 최소 개수의 제곱수 합으로 표현하는 컴퓨터 프로그램을 작성하시오.

입력

입력은 표준입력을 사용한다. 입력은 자연수 n을 포함하는 한 줄로 구성된다. 여기서, 1 ≤ n ≤ 50,000이다.

출력

출력은 표준출력을 사용한다. 합이 n과 같게 되는 제곱수들의 최소 개수를 한 줄에 출력한다.

예제 입력 1

25

예제 출력 1

1

예제 입력 2

26

예제 출력 2

2

예제 입력 3

11339

예제 출력 3

3

예제 입력 4

34567

예제 출력 4

4

출처

ICPC > Regionals > Asia Pacific > Korea > Nationwide Internet Competition > Seoul Nationalwide Internet Competition 2019 H번

  • 데이터를 추가한 사람: tktj12

알고리즘 분류


문제 풀이

# DynamicProgramming # Bruteforce # Math

Dynamic ProgrammingBruteforce를 활용하는 Math 문제입니다.


풀이 과정

문제의 난이도는 실버 4이지만, 풀이법이 잘 떠오르지 않아 애먹은 문제입니다. 문제들을 보다 보면, 문제의 난이도와 상관없이 잘 풀리지 않거나 쉽게 잘 풀리는 문제들이 있는 것 같습니다. 난이도는 낮은데 잘 풀리지 않는 문제들은 아이디어가 떠오르지 않는 경우가 많으니, 문제들을 많이 풀어보고 많은 아이디어들을 알고 있는게 중요한 것 같습니다.

자세한 설명은 아래에서 하겠습니다.


전체 코드

😂 1번 풀이: 시간 초과

첫 번째로 떠오른 풀이입니다. dp[i]dp[j]와 dp[i-j]의 조합으로 나타내고, i//2 이하의 모든 수들에 대해 j를 for문으로 돌립니다. 이렇게 해도 정답은 구할 수 있지만, 시간 복잡도가 O(N^2)이기 때문에 효율적이지 못 한 풀이입니다.

N = int(input())
dp = [float('inf') for _ in range(N+1)]
dp[0], dp[1] = 0, 1
for i in range(2,N+1):
    if int(i**(1/2)) == i**(1/2):
        dp[i] = 1
        continue
    for j in range(1,i//2+1):
        dp[i] = min(dp[i],dp[j]+dp[i-j])
        if dp[i] == 2: break
print(dp[N])

😊 2번 풀이: 정답(PyPy3)

그래서 시간 복잡도를 줄여야 합니다. 생각해보면, dp[i]를 구하기 위한 조합에 제한을 둘 수 있다는 것을 알 수 있습니다. 바로 제곱수가 포함될 때 가장 적은 개수의 수들로 표현할 수 있게 됩니다. 자신의 수에서 그보다 작은 수의 제곱수를 뺀 것의 최소를 구하고, 거기에 한 개(제곱 수)를 더해주면 됩니다.

따라서 j를 i//2 이하의 모든 수들에 대해 for문을 돌리지 말고, root(i) 이하의 자연수들 중 제곱수가 포함된 조합으로 나타낼 수 있는 것들만 탐색합니다. 이렇게 함으로써 시간 복잡도를 O(NlogN)으로 줄일 수 있습니다.

N = int(input())
dp = [float('inf') for _ in range(N+1)]
dp[0], dp[1] = 0, 1
for i in range(N+1):
    for j in range(1,int(i**(1/2))+1):
        dp[i] = min(dp[i], dp[i-j**2]+1) # min(dp[i], dp[i-j**2]+dp[j**2])
print(dp[N])

참고로, Python3로는 시간 초과를 해결할 수 없어 PyPy3로 통과한 풀이입니다.


정리

  • 시간 복잡도를 O(N^2)->O(NlogN)으로 줄일 때 제곱 수를 활용할 수 있다는 것을 배웠습니다.

Leave a comment