Source code for LibMTL.loss

import torch, time
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class AbsLoss(object):
    r"""An abstract class for loss functions.

    Subclasses implement :meth:`compute_loss`. The instance keeps a running
    history of per-batch scalar losses (``record``) and the matching batch
    sizes (``bs``) so a batch-size-weighted average can be reported.
    """
    def __init__(self):
        # Per-batch scalar loss values and the corresponding batch sizes.
        self.record = []
        self.bs = []

    def compute_loss(self, pred, gt):
        r"""Calculate the loss.

        Args:
            pred (torch.Tensor): The prediction tensor.
            gt (torch.Tensor): The ground-truth tensor.

        Return:
            torch.Tensor: The loss.
        """
        pass

    def _update_loss(self, pred, gt):
        # Compute the batch loss, then log its scalar value and batch size
        # (dim 0 of ``pred`` is assumed to be the batch dimension).
        loss = self.compute_loss(pred, gt)
        self.record.append(loss.item())
        self.bs.append(pred.size(0))
        return loss

    def _average_loss(self):
        # Batch-size-weighted mean of every loss recorded since the last
        # reinit. Guard the empty case: without it, bs.sum() is 0 and the
        # division yields NaN with a NumPy runtime warning.
        if not self.record:
            return 0.0
        record = np.array(self.record)
        bs = np.array(self.bs)
        return (record * bs).sum() / bs.sum()

    def _reinit(self):
        # Clear the recorded history (e.g. at the start of a new epoch).
        self.record = []
        self.bs = []
class CELoss(AbsLoss):
    r"""The cross-entropy loss function.
    """
    def __init__(self):
        super(CELoss, self).__init__()
        # Delegate to PyTorch's built-in cross-entropy criterion.
        self.loss_fn = nn.CrossEntropyLoss()

    def compute_loss(self, pred, gt):
        r"""Return the cross-entropy between ``pred`` and ``gt``."""
        return self.loss_fn(pred, gt)
class KLDivLoss(AbsLoss):
    r"""The Kullback-Leibler divergence loss function.
    """
    def __init__(self):
        super(KLDivLoss, self).__init__()
        # NOTE(review): nn.KLDivLoss is used with its default reduction and
        # presumably expects log-probabilities as input — verify callers.
        self.loss_fn = nn.KLDivLoss()

    def compute_loss(self, pred, gt):
        r"""Return the KL divergence between ``pred`` and ``gt``."""
        return self.loss_fn(pred, gt)
class L1Loss(AbsLoss):
    r"""The Mean Absolute Error (MAE) loss function.
    """
    def __init__(self):
        super(L1Loss, self).__init__()
        # Delegate to PyTorch's built-in L1 (mean absolute error) criterion.
        self.loss_fn = nn.L1Loss()

    def compute_loss(self, pred, gt):
        r"""Return the mean absolute error between ``pred`` and ``gt``."""
        return self.loss_fn(pred, gt)
class MSELoss(AbsLoss):
    r"""The Mean Squared Error (MSE) loss function.
    """
    def __init__(self):
        super(MSELoss, self).__init__()
        # Delegate to PyTorch's built-in mean-squared-error criterion.
        self.loss_fn = nn.MSELoss()

    def compute_loss(self, pred, gt):
        r"""Return the mean squared error between ``pred`` and ``gt``."""
        return self.loss_fn(pred, gt)