import numpy as np
from scipy.stats import norm

def kl_divergence(p, q):
    # KL(P || Q) = sum of p * log(p / q), skipping terms where p == 0
    return np.sum(np.where(p != 0, p * np.log(p / q), 0))

if __name__ == '__main__':
    x = np.arange(-10, 10, 0.001)  # evaluation grid with step 0.001
    p = norm.pdf(x, 0, 2)          # N(mean=0, std=2)
    q = norm.pdf(x, 2, 2)          # N(mean=2, std=2)
    print(kl_divergence(p, q))
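Note that the snippet sums raw density samples, so the printed number is the discrete KL of the grid values rather than the continuous KL integral; multiplying the sum by the grid spacing gives a Riemann-sum approximation of the integral. As a sanity check, the KL divergence between two Gaussians also has a closed form. The sketch below is an assumption on my part (the helper name kl_gaussian and the dx normalization are not in the original post):

import numpy as np
from scipy.stats import norm

def kl_gaussian(mu1, sigma1, mu2, sigma2):
    # Closed-form KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) )
    return np.log(sigma2 / sigma1) + (sigma1**2 + (mu1 - mu2)**2) / (2 * sigma2**2) - 0.5

x = np.arange(-10, 10, 0.001)
dx = 0.001
p = norm.pdf(x, 0, 2)
q = norm.pdf(x, 2, 2)

# Riemann-sum approximation of the KL integral
numeric = np.sum(np.where(p != 0, p * np.log(p / q), 0)) * dx

print(numeric)                  # roughly 0.5
print(kl_gaussian(0, 2, 2, 2))  # exactly 0.5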