DL&ML/code.data.tips

모델 학습이 잘 되는지 여부를 판단할 수 있는 지표

식피두 2021. 4. 27. 01:45

모델 학습이 잘 진행되는지 parameter normgradient norm을 활용할 수 있다. (김기현님 강의를 보다가 알게 됨...)

 

일반적으로(?)

  • parameter norm(L2)은 학습이 진행될 수록 커져야 한다.
    • 모델이 복잡해 지면서...
  • gradient norm(L2)는 점점 작아져야 한다.
    • grad norm이 크다? 그 만큼 많이 배우고 있다는 뜻. 학습이 진행되면서 점점 작아진다.
    • 학습 초반일 수록 틀리는 것이 많고, 많이 틀릴 수록 기울기가 가팔라짐.
@torch.no_grad()
def get_grad_norm(parameters, norm_type=2):
    parameters = list(filter(lambda p: p.grad is not None, parameters))

    total_norm = 0

    try:
        for p in parameters:
            total_norm += (p.grad.data**norm_type).sum()
        total_norm = total_norm ** (1. / norm_type)
    except Exception as e:
        print(e)

    return total_norm


@torch.no_grad()
def get_parameter_norm(parameters, norm_type=2):
    total_norm = 0

    try:
        for p in parameters:
            total_norm += (p.data**norm_type).sum()
        total_norm = total_norm ** (1. / norm_type)
    except Exception as e:
        print(e)

    return total_norm