기계학습은 사람이 명시적으로 프로그래밍하지 않고도 컴퓨터가 학습할 수 있는 능력을 제공하는 인공 지능을 말한다. 기본적으로 기계학습은 알고리즘을 사용하여 컴퓨터가 데이터를 학습하고, 패턴을 찾고, 예측을 할 수 있는 기능을 제공한다. 그러나 문제는 기존의 기계학습 모델은 데이터가 훈련된 데이터와 유사할 것을 전제로 하지만 항상 그렇지 않은 문제가 있다. 이를 보안하기 위해 나온 방법이 Transfer Learning이다.
Transfer Learning의 문제점은 다른 Task를 학습하다 보니, 원래 학습시킨 Task에 대한 정보를 잃어버린다는 것이다. 각 Task를 위한 모델을 따로 저장하는 방법이 있겠지만, 하나의 모델로 여러 Task를 처리하려는 Approach에서는 치명적인 약점이다. 이 문제점을 Catastrophic forgetting과 Semantic Drift라고 부르며, 의미가 변화하여 기존의 문제를 잃어버리게 되는 것이다. 각 weight들이 해당 Task를 배우는데 정확하게 어떤 Correlation이 있는지 모르는 Deep Learning에서 Fine Tuning을 위해 weight를 바꾸면, 기존 Task를 잃어버리는 것은 당연한 것이다.
Continual Learning목표
하나의 Model을 조금씩 업그레이드 시키면서, 여러 Task를 처리할 수 있도록 만드는 것이 목표이며, Neural Network를 업데이트하는 방식으로 두 가지 접근방법으로 나뉠 수 있다. ‘기존 Neural Network의 구조를 바꾸지 않고 Weight를 Fine Tuning 하는 방식’과 ‘기존 Neural Network의 구조를 조금 수정하는 방법’이다.
Elastic Weights Consolidation
첫 번째 방식 중 대표적인 알고리즘으로 Deep Mind에서 발표하였다.
Fine Tuning으로 weight를 섣불리 건드리면, 직전 Task를 잃어버리니, 직전 Model의중요한 weight를 업데이트하는 곳에서 regularization Term을 추가해서 조금만 수정하고, 나머지 weight들을 건드리는 알고리즘이다.
논문에 소개된 그림을 보면 L2 Regularization 과 No Penalty 알고리즘을 보면, 기존 Task A가 적은 Error를 갖는 구간을 벗어나 Task B로 이동하는 반면, EWC는 그 중간 지점을 교묘하게 잘 찾아가는 것을 볼 수 있다.
이때 ‘중요한 weight’를 어떻게 고를지에 대한 내용이 이 논문의 핵심 중 하나이다. 기존의 Task를 위한 모델과 새로 학습할 Task를 비교하는 과정이 EWC의 메인 Loss Function 안에 들어가 있다. F라는 함수는 Fisher Information Matrix로 어떤 Random Variable의 관측 값으로부터, 본포의 parameter에 대해 유추할 수 있는 정보의 양이다.
MNIST 손글씨와 패션 데이터를 EWC를 활용해 보았다. 이 둘은 서로 다른 데이터이지만 분류 레이블은 0~9로 같기 때문에 같은 학습 모델을 사용할 수 있다.
우선 손글씨 데이터를 학습한 다음 패션 데이터를 EWC를 적용하여 학습 후 손글씨와 패션 데이터를 분류하였을 때 손글씨 데이터가 분류가 되는지 확인해 보았다.
결과는 위의 표에 나온 결과처럼 패션 데이터를 EWC에 적용 후 손글씨 데이터를 분류하였을 때 약 79%의 정확도로 분류되었고, EWC를 적용하지 않고 패션 데이터를 학습 시킨 후 손글씨 데이터를 분류하였을 때는 약 27%의 정확도로 분류 되었다.