DL&ML/papers

DistilBERT (a distilled version of BERT: smaller, faster, cheaper and lighter)

식피두 2021. 4. 23. 00:31

arxiv.org/abs/1910.01108

 

DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter

As Transfer Learning from large-scale pre-trained models becomes more prevalent in Natural Language Processing (NLP), operating these large models in on-the-edge and/or under constrained computational training or inference budgets remains challenging. In t

arxiv.org

Knowledge Distillation에 대해 훑어 보고 있는데, KD는 먼 곳에 있지 않았다.

 

BERT에 KD를 적용한 게 DistilBERT...

 

KD에 관해서는 서베이 논문을 훑어보고 있는데, 따로 정리할 예정이다.

그 이전에 당장 버트에는 KD가 어떻게 활용되었는지 궁금했다.

 

여기서 중요한 것은, 특정 태스크 전용으로 KD를 적용하는게 아니라 (QA model, STS, ...)

Pre-Training 단계에서 부터 KD를 적용해서

General Purpose Language Representation Model을 만들 수 있다는 점!

 

뭐, 이렇게 해서 기존 BERT 대비 사이즈는 40%,

NLU 능력은 97% 유지, 속도는 60% 빨라 졌다고 함.

latency에 민감한 서비스를 구성할 때는 유용하게 활용될 방법이다.

학습 방법 (Knowledge Distillation)

KD는 모델 압축 기술 중 하나이며,

큰 모델(Teacher), 작은 모델(Student)을 두어 학생이 선생의 동작 방식을 배울 수 있도록 한다.

(cf. ALBERT는 embedding을 더 작게 분해하고, layer간의 weight을 공유함으로써 모델을 압축함)

 

학생은 선생의 'soft target probability'를 배운다. 

 

t_i는 선생의 출력 확률 분포

 

기존의 CrossEntropy를 이용한 학습은

모델의 예측 확률 분포를 정답 one-hot 분포에 맞추어 (정답 위치 확률 최대화) 학습하게 된다.

학습 셋에 잘 피팅이 된 모델이라면 특정 클래스 확률은 높고 나머지는 거의 zero 에 가까운 확률 분포를 출력하게 된다.

이 때 모델의 일반화(Generalization)능력에 기여하는 부분은 바로 'near-zero' 부분 이라고 논문에서 언급하고 있다.

 

따라서 BERT를 KD할 때,

선생 모델이 출력하는 확률 분포 자체를 배움으로써

학생 모델이 자신 보다 복잡한 모델들만이 배울 수 있는 signal 또한 함께 배울 수 있다.

 

softmax-temparature

 

여기서 T를 도입하면 분포의 smoothness를 조정할 수 있다. 

T는 학습 과정에서 학생/선생 모두에게 적용되며 추론시엔 제외 시킨다.

 

Final Training Objective

위의 내용을 종합하여 최종 Training Objective를 정의하면 다음과 같다.

 

Final Training Objective = Distillation Loss(CE) + Masked Language Modeling Loss + Cosine Embedding Loss

 

(마지막 코사인 임베딩 로스는 학생/선생의 히든 벡터가 바라보는 방향을 일치 시켜주는데 도움을 주는 로스다)

 

Detail

  • Student 모델은 BERT에서 token-type embedding 및 pooler를 제거 + 레이어 개수 1/2(?) 을 줄인 버전의 모델이다.
  • Student의 weight 초깃값은 Teacher 모델의 weight을 이용하여 초기화 함
    • 레이어 개수를 절반으로 줄였으므로 동일 위치 레이어 + 인접 레이어 중 한 레이어 weight을 취한 듯
  • 배치는 4k 로 구성 되었고, dynamic masking + NSP objective 로 학습 되었다.