[AITech] 20220121 - RNN 기초
강의 복습 내용
RNN 기초
시퀀스 데이터
-
소리, 문자열, 주가 등의 데이터를 시퀀스 데이터로 분류한다.
-
시퀀스 데이터는 독립동등분포(i.i.d.) 가정을 위배하기 때문에 순서를 바꾸거나 과거 정보에 손실이 발생하면 데이터의 확률분포도 바뀐다.
-
따라서 이전 시퀀스에 대한 정보를 가지고 앞으로 발생할 데이터의 확률 분포를 계산해야 하며, 이를 위해 조건부 확률을 이용할 수 있다.
- 위 조건부 확률은 과거의 모든 정보를 이용하지만, 시퀀스 데이터를 분석할 때 과거의 모든 정보들이 필요한 것은 아니다.
- 어떤 시점까지의 과거의 정보를 이용할 지는 데이터/모델링에 따라 달라진다.
- 위 조건부 확률은 과거의 모든 정보를 이용하지만, 시퀀스 데이터를 분석할 때 과거의 모든 정보들이 필요한 것은 아니다.
-
시퀀스 데이터를 다루기 위해서는 길이가 가변적인 데이터를 다룰 수 있는 모델이 필요하다.
- 이를 해결하기 위해 특정 구간 _tau_만큼의 과거 정보만을 이용하고, 그보다 더 전의 정보들은 Ht라는 잠재변수로 인코딩해서 사용할 수 있다.
- 이렇게 함으로써 데이터의 길이를 고정할 수 있고, 과거의 모든 데이터를 활용하기 용이해진다.
-
이 잠재변수 Ht를 신경망을 통해 반복해서 사용하여 시퀀스 데이터의 패턴을 학습하는 모델이 RNN이다.
- 이를 해결하기 위해 특정 구간 _tau_만큼의 과거 정보만을 이용하고, 그보다 더 전의 정보들은 Ht라는 잠재변수로 인코딩해서 사용할 수 있다.
RNN(Recurrent Neural Network)
-
현재 정보만을 입력으로 사용하는 완전연결신경망은 과거의 정보를 다룰 수 없다.
-
RNN은 이전 순서의 잠재변수와 현재의 입력을 활용하여 모델링한다.
- W: t에 따라 불변/ X, H: t에 따라 가변
-
RNN의 역전파는 잠재변수의 연결그래프에 따라 순차적으로 계산한다. (맨 마지막 출력까지 계산한 후에 역전파)
- 이를 BPTT(Backpropagation Through Time)라 하며 RNN의 기본적인 역전파 방식이다.
-
BPTT를 통해 RNN의 가중치 행렬의 미분을 계산해보면 아래와 같이 미분의 곱으로 이루어진 항이 계산된다.
- 그 중 빨간색 네모 안의 항은 불안정해지기 쉽다.
- 이는 거듭된 값들의 곱으로 인해 값이 너무 커지거나(기울기 폭발) 너무 작아져(기울기 소실) 과거의 정보를 제대로 전달해주지 못하기 때문이다.
-
기울기 폭발/소실 문제를 해결하기 위해 역전파 과정에서 길이를 끊는 것이 필요하며, 이를 TBPTT(Truncated BPTT)라 한다.
-
여러가지 문제로 Vanilla RNN으로는 긴 시퀀스를 처리하는데 한계가 있고, 이를 해결하기 위해 LSTM이나 GRU와 같은 발전된 형태의 네트워크를 사용한다.
Leave a comment