torch.nn.BCEWithLogitsLoss
이진 분류 문제를 풀 때 쓰는 BCEWithLogitsLoss
Sigmoid layer + BCELoss의 조합으로 구현되어 있다.
이미 최적화 되어 있으므로 별도로 구현하기 보다 갖다 쓰는게 낫다고 한다.
기본적으로는 'mean'으로 reduction 되어 스칼라 값을 리턴한다.
single label 이진 분류 문제 뿐만 아니라
multi-label 이진 분류 문제를 풀 때도 활용 가능하다.
single label 이진 분류 예시
logits = model(ids,
token_type_ids=token_type_ids,
attention_mask=mask,
ans_indices=ans_indices)
# |logits| = (batch, 1)
# |labels| = (batch,)
loss = torch.nn.BCEWithLogitsLoss()(logits, labels.view(-1, 1))
multi-label 이진 분류 예시
# 문장 내 정답 sub-string의 시작과 끝 토큰 인덱스를 찾으면서 (QA-model)
# 각 토큰이 sub-string에 속할 확률(multi-labels)을 출력하는 모델
start_logits, end_logits, seq_logits = model(ids,
token_type_ids=token_type_ids,
attention_mask=mask)
# |start_logits| = |end_logits| = (batch, seq_len)
# |seq_logits| = (batch, seq_len)
start_loss = torch.nn.CrossEntropyLoss()(start_logits, targets_start) # 시작 토큰 위치
end_loss = torch.nn.CrossEntropyLoss()(end_logits, targets_end) # 끝 토큰 위치
seq_loss = torch.nn.BCEWithLogitsLoss()(seq_logits, targets_seq)
# |target_seq| = (batch, seq_len)
pos_weight 파라미터를 통해 각 클래스별 recall/precision tradeoff를 조정할 수 있다.
1보다 크면 recall이 올라가고, 1보다 작을 경우 precision을 높일 수 있다.
하나의 클래스에 대해 100개의 postive, 300개의 negative examples가 있다면
300/100 = 3을 pos_weight에 줌으로써 loss 함수가 300개의 positive가 있는 것 처럼
동작하도록 해줄 수 있다.
target = torch.ones([10, 64], dtype=torch.float32)
output = torch.full([10, 64], 1.5) # 1.5로 채워진 10x64 행렬
pos_weight = torch.ones([64]) # 길이 64를 1로 채움
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target) # -log(sigmoid(1.5))
torch.nn.CrossEntropyLoss
CrossEntropyLoss는 torch.nn.LogSoftmax + torch.nn.NLLLoss 의 조합으로 구현되어 있다.
분류 모델의 마지막 레이어에 LogSoftmax를 씌우는 것을 선호한다면
NLLLoss를 써야 한다.
LogSoftmax는 다음의 수식을 구현 (특정 클래스 확률 값에 로그 취한 것)
NLLLoss (negative log likelihood loss)는 LogSoftmax의 결과를 단순히 -를 붙여 취합한 것
m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
input = torch.randn(3, 5, require_grad=True)
target = torch.tensor([1, 0, 4])
output = loss(m(input, target)
output.backward()
참고자료
'DL&ML > code.data.tips' 카테고리의 다른 글
모델 학습이 잘 되는지 여부를 판단할 수 있는 지표 (1) | 2021.04.27 |
---|---|
Kaggle TSE 2020 대회 top-solution 정리 (0) | 2021.04.15 |
torch amp mixed precision (autocast, GradScaler) (0) | 2021.04.13 |
Focal Loss (2) | 2021.04.13 |
KL divergence 구현 예시 (0) | 2021.04.13 |