상세 컨텐츠

본문 제목

Continual Learning

Technology/Tech Insight

by JunHyuk_Kim 2022. 12. 21. 16:56

본문

Continual Learning 이란

  • 기계학습은 사람이 명시적으로 프로그래밍하지 않고도 컴퓨터가 학습할 수 있는 능력을 제공하는 인공 지능을 말한다. 기본적으로 기계학습은 알고리즘을 사용하여 컴퓨터가 데이터를 학습하고, 패턴을 찾고, 예측을 할 수 있는 기능을 제공한다. 그러나 문제는 기존의 기계학습 모델은 데이터가 훈련된 데이터와 유사할 것을 전제로 하지만 항상 그렇지 않은 문제가 있다. 이를 보안하기 위해 나온 방법이 Transfer Learning이다.
  • Transfer Learning의 문제점은 다른 Task를 학습하다 보니, 원래 학습시킨 Task에 대한 정보를 잃어버린다는 것이다. 각 Task를 위한 모델을 따로 저장하는 방법이 있겠지만, 하나의 모델로 여러 Task를 처리하려는 Approach에서는 치명적인 약점이다. 이 문제점을 Catastrophic forgetting과 Semantic Drift라고 부르며, 의미가 변화하여 기존의 문제를 잃어버리게 되는 것이다. 각 weight들이 해당 Task를 배우는데 정확하게 어떤 Correlation이 있는지 모르는 Deep Learning에서 Fine Tuning을 위해 weight를 바꾸면, 기존 Task를 잃어버리는 것은 당연한 것이다.

Continual Learning 목표

  • 하나의 Model을 조금씩 업그레이드 시키면서, 여러 Task를 처리할 수 있도록 만드는 것이 목표이며, Neural Network를 업데이트하는 방식으로 두 가지 접근방법으로 나뉠 수 있다. ‘기존 Neural Network의 구조를 바꾸지 않고 Weight를 Fine Tuning 하는 방식’과 ‘기존 Neural Network의 구조를 조금 수정하는 방법’이다.

Elastic Weights Consolidation

  • 첫 번째 방식 중 대표적인 알고리즘으로 Deep Mind에서 발표하였다.
  • Fine Tuning으로 weight를 섣불리 건드리면, 직전 Task를 잃어버리니, 직전 Model의중요한 weight를 업데이트하는 곳에서 regularization Term을 추가해서 조금만 수정하고, 나머지 weight들을 건드리는 알고리즘이다.

  • 논문에 소개된 그림을 보면 L2 Regularization 과 No Penalty 알고리즘을 보면, 기존 Task A가 적은 Error를 갖는 구간을 벗어나 Task B로 이동하는 반면, EWC는 그 중간 지점을 교묘하게 잘 찾아가는 것을 볼 수 있다.
  • 이때 ‘중요한 weight’를 어떻게 고를지에 대한 내용이 이 논문의 핵심 중 하나이다. 기존의 Task를 위한 모델과 새로 학습할 Task를 비교하는 과정이 EWC의 메인 Loss Function 안에 들어가 있다. F라는 함수는 Fisher Information Matrix로 어떤 Random Variable의 관측 값으로부터, 본포의 parameter에 대해 유추할 수 있는 정보의 양이다.

MNIST 손글씨와 패션 데이터를 EWC를 활용해 보았다. 이 둘은 서로 다른 데이터이지만 분류 레이블은 0~9로 같기 때문에 같은 학습 모델을 사용할 수 있다.

우선 손글씨 데이터를 학습한 다음 패션 데이터를 EWC를 적용하여 학습 후 손글씨와 패션 데이터를 분류하였을 때 손글씨 데이터가 분류가 되는지 확인해 보았다.

각 데이터의 shape

결과는 위의 표에 나온 결과처럼 패션 데이터를 EWC에 적용 후 손글씨 데이터를 분류하였을 때 약 79%의 정확도로 분류되었고, EWC를 적용하지 않고 패션 데이터를 학습 시킨 후 손글씨 데이터를 분류하였을 때는 약 27%의 정확도로 분류 되었다. 

 

Reference