DL&ML/code.data.tips

Focal Loss

식피두 2021. 4. 13. 16:23

Object Detection에서 Background / Foreground Class의 불균형 문제를

로스 함수로 해결하기 위해 제안된 focal loss

 

이걸 클래스 분균형이 심한 일반 분류 문제에도

적용할 수 있을 것 같아서 살펴보았다..

 

핵심 아이디어는 다음과 같다.

모델 입장에서 쉽다고 판단하는 example에 대해서

모델의 출력 확률(confidence) Pt가 높게 나올테니

 

(1-Pt)^gamma를 CE에 추가해줌으로써

높은 확신에 대해 패널티를 주는 방법

 

반대로 어려워하고 있는 example에 대해선

Pt가 낮게 나올테니 (1-Pt)^gamma가 상대적으로 높게 나올 것!

 

gamma가 높을 수록 (1-Pt)가 작을 수록 더 작아진다

(확신이 높은 example은 패널티를 더 받음)

 

일반적인 분류문제에 적용하고 싶다면?

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):nn.CrossEntropyLoss()
    
        ce_loss = nn.CrossEntropyLoss()(inputs, targets, reduction='none')

        pt = torch.exp(-ce_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * ce_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

 

discuss.pytorch.org/t/focal-loss-for-imbalanced-multi-class-classification-in-pytorch/61289

 

Focal loss for imbalanced multi class classification in Pytorch

I want an example code for Focal loss in PyTorch for a model with three class prediction. My model outputs 3 probabilities. Sentiment_LSTM( (embedding): Embedding(19612, 400) (lstm): LSTM(400, 512, num_layers=2, batch_first=True, dropout=0.5) (dropout): Dr

discuss.pytorch.org

m.blog.naver.com/PostView.nhn?blogId=sogangori&logNo=221087066947&proxyReferer=https:%2F%2Fwww.google.com%2F

 

Focal Loss for Dense Object Detection

Focal Loss for Dense Object Detection Tsung-Yi Lin Priya Goyal Ross Girshick Kaiming H...

blog.naver.com