DL&ML/code.data.tips

torch amp mixed precision (autocast, GradScaler)

식피두 2021. 4. 13. 17:14

1.5 버전 부터인가 nvidia의 amp모듈이 torch 기본 모듈로 자리잡았다.

 

amp의 Mixed Precision 기능을 이용하면

float16으로 Type Casting 되는 것이 빠른 연산(Linear Layer, Conv Layer etc.)은

float16으로 변환해서 연산을 수행하는 것이 가능해진다. (계산 정확도를 유지하는 선에서)

 

https://pytorch.org/docs/stable/amp.html

 

torch amp 모듈은 autocasting을 위한 모듈을 제공하며,

아래 예시 코드에서 확인할 수 있듯 그 이름도 autocast이다. 

 

with 문과 함께 선언해서 사용하면

그 안에 선언되는 토치 연산들은 mixed precision으로 실행이 된다.

model의 forward 연산과, loss 계산 연산을 with문 아래에 위치 시키자.

 

autocast와 함께, GradScaler를 같이 사용해주어야 하는데,

스케일러의 용도는 다음과 같다.

 

foward-pass에서 float16으로 계산된 연산 결과는

backward-pass의 결과로 float16이 생성되는데,

만약 그래디언트가 너무 작다면, float16으로 표현될 수 없어 underflow가 일어나버릴 수 있다.

 

따라서 이런 일을 방지하기 위해

backward시 float32범위로 스케일을 해줄 필요가 있는 것!

 

뿐만 아니라, 옵티마이저가 파라미터를 업데이트하기 전에 원래 스케일대로 복구를 해 놔야

스케일 정도가 러닝 레잇에 영향 끼치는 것을 막을 수 있다.

 

아래는 내가 참고하려고 만든

더러운 예시.. 원래 문서의 예시를 참고하자.

import torch.nn.utils as torch_utils

from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler

model = 모델 초기화
optimizer = 옵티마이저(model.parameters(), ...)

# 생략 ...

scaler = GradScaler()

for step, batch in enumerate(tqdm(train_data_loader, desc="Train", ncols=80)):

    model.zero_grad()

    with autocast():
        logits = model(
            input_ids=ids,
            token_type_ids=token_type_ids,
            attention_mask=mask
        )

        loss = torch.nn.CrossEntropyLoss(weight=ce_weights)(
            logits, rel_type
        )  # logits => (batch * class_num), rel_type => (batch * 1)
        loss /= iters_to_accumulate

    scaler.scale(loss).backward()

    if (step + 1) % iters_to_accumulate == 0:
        torch_utils.clip_grad_norm_(
            model.parameters(),
            1e8,
        )

        scaler.step(optimizer)
        scaler.update()

        scheduler.step()

 

공식 문서

 

Automatic Mixed Precision package - torch.cuda.amp — PyTorch 1.8.1 documentation

The following lists describe the behavior of eligible ops in autocast-enabled regions. These ops always go through autocasting whether they are invoked as part of a torch.nn.Module, as a function, or as a torch.Tensor method. If functions are exposed in mu

pytorch.org

 

Automatic Mixed Precision examples — PyTorch 1.8.1 documentation

Shortcuts

pytorch.org

 

pytorch.org/docs/stable/amp.html