DL&ML/papers

Knowledge Distillation: A Survey

식피두 2021. 4. 28. 01:30

arxiv.org/pdf/2006.05525.pdf

 

모델 경량화 방법인 Knowledge Distillation (이하 KD) 서베이 논문.

 

KD가 무엇으로 구성되고 어떻게 학습이 이루어지는지에 관한 것들을 정리해보고자 한다.

 

딥러닝 모델을 한정된 자원을 가진 모바일 디바이스로 배포하고 싶다면 모델의 경량화는 필수다.

이 때 KD를 이용하면 모델을 압축시킬 수 있을 뿐만 아니라 추론 속도도 가속시킬 수 있다.

 

딥러닝 기반의 실서비스를 구성할 때도 모델 경량화기법이 유용하게 활용될 수 있다.

모델 Compression / Acceleration 방법

  • Parameter Pruning / Sharing
  • Low-rank factorization
  • Knowledge Distillation
  • 등등

Knowledge Distillation

작은 모델의 Student, 큰 모델의 Teacher 모델로 구성되어 (Capacity Gap)

Teacher의 Knowledge를 Student에게 주입시킨다.

 

 

KD는 다음 세 개 요소로 구성 된다.

  • Knowledge
  • Distillation Algorithm
  • Teacher Student Architecture

 

아래는 바닐라 KD 모델의 구조도이다.

Student 모델은 실제 정답에 대해서 학습하는 동시에

Teacher에 의해 생성된 Soft Targets들에 대해서도 학습한다. (Cross Entropy)

 

bench mark model of vanilla KD

 

Knowledge?

KD를 통해서 Student에게 주입 시키고자 하는 Knowledge는 3가지 정도로 구분 된다.

 

Response/Feature/Relation Based Knowledge

 

Response-Based Knowledge

  • Teacher 모델의 마지막 예측 값 (logits)
  • Teacher 모델의 확률 분포를 soft targets 삼아 Student에게 학습 시킬 수
    • 두 분포를 KL-Divergence loss를 이용해 학습

Responsed-Based KD

 

Feature-Based Knowledge

  • 뉴럴넷의 intermiediate 레이어에 대한 Knowledge를 주입
    • hints 라고도 표현함

Feature-Based Knowledge

 

Relation-Based Knowledge

직관적으로 와닿지는 않는 방법이다. 실험 부분을 참조해봐도 자주 쓰이지는 않는 것 처럼 보인다.

  • feature들 간의 relation을 Knowledge로서 학습
  • inner product btw. features from 2 layers

 

Distillation Scheme

Teacher와 Student를 학습하기 위한 학습 스킴에도 세 가지가 있다.

 

Offline 방법은 Teacher 모델을 프리트레인 시켜 놓고 Student에게 KD를 적용한다. (2-phase)

Online 방법은 Teacher와 Stduent를 동시에 학습 시키거나, 번갈아 가면서 학습 시킨다. (1-phase?)

Self Distillation은 깊은 레이어의 표현을 얕은 레이어로 주입시킨다.

 

 

Teacher-Student Architecture

Student를 Teacher 보다 작게 만들 때에는

depth/width를 작게 할지, 적은 레이어를 쓸지,

precision에 제한을 둘지(quantization) 등의 고민이 필요하다.

여기서 NAS(Network Architecture Search) 기법을 활용하기도...

 

Distillation Algorithm

  • 가장 간단한 방법은? 이미 언급 된 것 처럼
    • Teacher-Student간의 Knowledge를 직접 매치시켜 학습 시키는 방법이다.
      • Reponse를 비교하든, Feature를 비교하든...
      • 분류 문제라면 CE 혹은 KL-divergence Loss를 써서...
  • 그 외에도 다양한 기법이 시도되고 있다.
    • GAN을 이용해서 synthetic data를 생성(hard example)해서 학습
    • 여러가지 Teacher와 함께 학습
      • 각각의 pair(T_1~N, S)로 학습을 하거나
      • Teacher들의 출력을 평균내어 averaged logits과 비교하여 학습하거나
    • Data-Free Distillation이라고 해서, 데이터 없이 KD를 하는 방법
    • Quantization Distillation
      • 네트워크의 precision을 낮추기
      • High precision teacher & Low precision student

활용

  • BERT의 CLS 토큰 임베딩을 hint로 삼아 Student에게 Transfer
    • 문장 분류, 매칭, MRC에서 유용할 수
  • two-stage transformer로 KD
    • TinyBERT
  • BERT => BiLSTM