import torch, time
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class AbsMetric(object):
    r"""An abstract class for the performance metrics of a task.

    Subclasses are expected to override :meth:`update_fun` and
    :meth:`score_fun`. Both are ordinary methods (not properties) so that
    they can be called with arguments, exactly like the overrides.

    Attributes:
        record (list): A list of the metric scores in every iteration.
        bs (list): A list of the number of data in every iteration.
    """

    def __init__(self):
        self.record = []
        self.bs = []

    def update_fun(self, pred, gt):
        r"""Calculate the metric scores in every iteration and update :attr:`record`.

        Args:
            pred (torch.Tensor): The prediction tensor.
            gt (torch.Tensor): The ground-truth tensor.

        Raises:
            NotImplementedError: Always; the base class defines no metric.
        """
        raise NotImplementedError(
            "update_fun must be implemented by subclasses of AbsMetric"
        )

    def score_fun(self):
        r"""Calculate the final score (when an epoch ends).

        Return:
            list: A list of metric scores.

        Raises:
            NotImplementedError: Always; the base class defines no metric.
        """
        raise NotImplementedError(
            "score_fun must be implemented by subclasses of AbsMetric"
        )

    def reinit(self):
        r"""Reset :attr:`record` and :attr:`bs` (when an epoch ends).
        """
        # Fresh lists are bound (rather than cleared in place) so that any
        # reference a caller kept to the previous epoch's lists stays intact.
        self.record = []
        self.bs = []
# accuracy
class AccMetric(AbsMetric):
    r"""Calculate the accuracy.

    Each call to :meth:`update_fun` stores the number of correct predictions
    of one batch in :attr:`record` and the batch size in :attr:`bs`;
    :meth:`score_fun` divides the two totals.
    """

    def __init__(self):
        super(AccMetric, self).__init__()

    def update_fun(self, pred, gt):
        r"""Count the correct predictions of one batch.

        Args:
            pred (torch.Tensor): Per-class scores; the predicted class is the
                index of the largest value along the last dimension.
            gt (torch.Tensor): Ground-truth class indices, compared
                element-wise against the predicted indices.
        """
        # Softmax is monotonic along the reduced dimension, so the arg-max of
        # the raw scores already is the predicted class; no need to normalise.
        labels = pred.argmax(dim=-1)
        self.record.append(gt.eq(labels).sum().item())
        # The batch size is taken from the first dimension of the predicted
        # labels (assumes a batch-first layout -- TODO confirm with callers).
        self.bs.append(labels.size(0))

    def score_fun(self):
        r"""Return the accuracy accumulated since the last reset.

        Return:
            list: A one-element list holding ``correct / total``.

        Raises:
            ZeroDivisionError: If no sample has been recorded yet.
        """
        correct = sum(self.record)
        total = sum(self.bs)
        return [correct / total]
# L1 Error
class L1Metric(AbsMetric):
    r"""Calculate the Mean Absolute Error (MAE).

    Each call to :meth:`update_fun` stores the mean absolute error of one
    batch in :attr:`record` and the batch size in :attr:`bs`;
    :meth:`score_fun` returns the batch-size-weighted mean of these values.
    """

    def __init__(self):
        super(L1Metric, self).__init__()

    def update_fun(self, pred, gt):
        r"""Record the mean absolute error of one batch.

        Args:
            pred (torch.Tensor): The prediction tensor.
            gt (torch.Tensor): The ground-truth tensor, subtracted
                element-wise from ``pred``.
        """
        abs_err = torch.abs(pred - gt)
        # Reduce to a scalar before ``.item()``: ``item()`` only accepts
        # single-element tensors, and ``score_fun`` re-weights every record
        # by its batch size, so the per-batch mean is the value to store.
        self.record.append(abs_err.mean().item())
        self.bs.append(pred.size(0))

    def score_fun(self):
        r"""Return the MAE accumulated since the last reset.

        Return:
            list: A one-element list holding the average of the recorded
            per-batch errors, weighted by batch size.
        """
        records = np.array(self.record)
        batch_size = np.array(self.bs)
        # Weight each per-batch mean by its batch size so that batches of
        # different sizes contribute proportionally.
        return [(records * batch_size).sum() / batch_size.sum()]