DL&ML/papers

ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators

식피두 2021. 4. 16. 01:28

arxiv.org/abs/2003.10555

 

Electra 모델은 어떻게 학습이 되는지 알아보자.

ELECTRA?

Masked Language Modeling(MLM) pre-training 방법은 입력 일부를 [MASK] 토큰으로 변경해버린 뒤 원래 토큰을 복원하는 식으로 학습을 한다.

 

그런데 이게 과연 효율적인가? 라는 의문에서 Electra의 아이디어가 나왔다.

 

마스킹을 할 때 15% 정도의 확률로 선택을 하고, 마스킹 된 것을 원본으로 복원하는 것을 학습하는데, 하나의 Example 당 15% 토큰만 학습에 기여하기 때문에 계산 효율적이지 못하다.

 

 

Electra에선 Replaced Token Detection (RTD) 방식의 pre-training 방법을 제안한다.

단순히 특정 토큰을 [MASK]로 마스킹 해버리는 것이 아니라, 그럴듯한 단어로 바꿔버리는 것!

 

Generator작은 크기(작은 크기로 둬야 그럴듯 하게 실수를 하니까?)의 MLM를 두고,

[MASK] 표시가 된 입력을 넣으면 그럴듯한 단어로 바뀌어 출력 되는데

이 출력을 Discriminator에 넣어,  '모든 토큰에 대해' 바뀌었는지, 안바뀌었는지 여부를 분류한다.

 

기존의 MLM은 학습 단계에서 사용하는 [MASK] 토큰이

다운 스트림 태스크에서 fine-tuning 될 때는 등장하지 않아

네트워크 입력의 mis-match가 발생하는 문제가 있었지만, Electra는 그런 문제가 없다.

 

어쨌든 RTD의 장점은 바로,

모델이 모든 입력 토큰으로 부터 knowledge를 (빠르고, 효율적으로) 배울 수 있다는 것이다.

 

Generator

제너레이터는 MLM 으로 학습이 되며, 주어진 입력 x=[x1, x2, ... , xn]에 대해

ceil(n*0.15) 개 만큼 [MASK]을 하고, 마스킹 된 토큰이 원래 무엇이었는지를 학습한다.

그리고, 마스킹 된 부분이 원래 뭐였는지 복구된 문장을 currupted example이라고 하자.

Discriminator

generator에 의해 생성된 currupted example을 입력으로 받아 각 토큰이 변형되었는지 아닌지를 학습한다.

 

각각의 로스 함수는 다음과 같이 정의 된다.