Source code for LibMTL.trainer

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL._record import _PerformanceMeter
from LibMTL.utils import count_parameters

class Trainer(nn.Module):
    r'''A Multi-Task Learning Trainer.

    This is a unified and extensible training framework for multi-task learning.

    Args:
        task_dict (dict): A dictionary of name-information pairs of type (:class:`str`, :class:`dict`). \
            The sub-dictionary for each task has four entries whose keywords are named **metrics**, \
            **metrics_fn**, **loss_fn**, **weight**, and each of them corresponds to a :class:`list`. \
            The list of **metrics** has ``m`` strings, representing the names of the ``m`` metrics \
            for this task. The list of **metrics_fn** has two elements, i.e., the updating and score \
            functions, specifying how to update those objectives in the training process and how to \
            obtain the final scores, respectively. The list of **loss_fn** has ``m`` loss functions \
            corresponding to each metric. The list of **weight** has ``m`` binary integers corresponding \
            to each metric, where ``1`` means a higher score indicates better performance and ``0`` \
            means the opposite.
        weighting (class): A weighting strategy class based on :class:`LibMTL.weighting.abstract_weighting.AbsWeighting`.
        architecture (class): An architecture class based on :class:`LibMTL.architecture.abstract_arch.AbsArchitecture`.
        encoder_class (class): A neural network class.
        decoders (dict): A dictionary of name-decoder pairs of type (:class:`str`, :class:`torch.nn.Module`).
        rep_grad (bool): If ``True``, the gradient of the representation for each task can be computed.
        multi_input (bool): ``True`` if each task has its own input data; ``False`` otherwise.
        optim_param (dict): A dictionary of configurations for the optimizer.
        scheduler_param (dict): A dictionary of configurations for the learning rate scheduler. \
            Set it to ``None`` if you do not use a learning rate scheduler.
        kwargs (dict): A dictionary of hyperparameters of the weighting and architecture methods.

    .. note::
        It is recommended to use :func:`LibMTL.config.prepare_args` to return the dictionaries of \
        ``optim_param``, ``scheduler_param``, and ``kwargs``.

    Examples::

        import torch.nn as nn
        from LibMTL import Trainer
        from LibMTL.loss import CE_loss_fn
        from LibMTL.metrics import acc_update_fun, acc_score_fun
        from LibMTL.weighting import EW
        from LibMTL.architecture import HPS
        from LibMTL.model import ResNet18
        from LibMTL.config import prepare_args

        task_dict = {'A': {'metrics': ['Acc'],
                           'metrics_fn': [acc_update_fun, acc_score_fun],
                           'loss_fn': [CE_loss_fn],
                           'weight': [1]}}

        decoders = {'A': nn.Linear(512, 31)}

        # You can use command-line arguments and return configurations by ``prepare_args``.
        # kwargs, optim_param, scheduler_param = prepare_args(params)
        optim_param = {'optim': 'adam', 'lr': 1e-3, 'weight_decay': 1e-4}
        scheduler_param = {'scheduler': 'step'}
        kwargs = {'weight_args': {}, 'arch_args': {}}

        trainer = Trainer(task_dict=task_dict,
                          weighting=EW,
                          architecture=HPS,
                          encoder_class=ResNet18,
                          decoders=decoders,
                          rep_grad=False,
                          multi_input=False,
                          optim_param=optim_param,
                          scheduler_param=scheduler_param,
                          **kwargs)
    '''
    def __init__(self, task_dict, weighting, architecture, encoder_class, decoders,
                 rep_grad, multi_input, optim_param, scheduler_param,
                 save_path=None, load_path=None, **kwargs):
        super(Trainer, self).__init__()

        self.device = torch.device('cuda:0')
        self.kwargs = kwargs
        self.task_dict = task_dict
        self.task_num = len(task_dict)
        self.task_name = list(task_dict.keys())
        self.rep_grad = rep_grad
        self.multi_input = multi_input
        self.scheduler_param = scheduler_param
        self.save_path = save_path
        self.load_path = load_path

        self._prepare_model(weighting, architecture, encoder_class, decoders)
        self._prepare_optimizer(optim_param, scheduler_param)

        self.meter = _PerformanceMeter(self.task_dict, self.multi_input)

    def _prepare_model(self, weighting, architecture, encoder_class, decoders):

        class MTLmodel(architecture, weighting):
            def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, kwargs):
                super(MTLmodel, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)
                self.init_param()

        self.model = MTLmodel(task_name=self.task_name,
                              encoder_class=encoder_class,
                              decoders=decoders,
                              rep_grad=self.rep_grad,
                              multi_input=self.multi_input,
                              device=self.device,
                              kwargs=self.kwargs['arch_args']).to(self.device)
        if self.load_path is not None:
            if os.path.isdir(self.load_path):
                self.load_path = os.path.join(self.load_path, 'best.pt')
            self.model.load_state_dict(torch.load(self.load_path), strict=False)
            print('Load Model from - {}'.format(self.load_path))
        count_parameters(self.model)

    def _prepare_optimizer(self, optim_param, scheduler_param):
        optim_dict = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'adagrad': torch.optim.Adagrad,
            'rmsprop': torch.optim.RMSprop,
        }
        scheduler_dict = {
            'exp': torch.optim.lr_scheduler.ExponentialLR,
            'step': torch.optim.lr_scheduler.StepLR,
            'cos': torch.optim.lr_scheduler.CosineAnnealingLR,
            'reduce': torch.optim.lr_scheduler.ReduceLROnPlateau,
        }
        optim_arg = {k: v for k, v in optim_param.items() if k != 'optim'}
        self.optimizer = optim_dict[optim_param['optim']](self.model.parameters(), **optim_arg)
        if scheduler_param is not None:
            scheduler_arg = {k: v for k, v in scheduler_param.items() if k != 'scheduler'}
            self.scheduler = scheduler_dict[scheduler_param['scheduler']](self.optimizer, **scheduler_arg)
        else:
            self.scheduler = None

    def _process_data(self, loader):
        try:
            data, label = next(loader[1])
        except StopIteration:
            # The dataloader is exhausted; restart its iterator.
            loader[1] = iter(loader[0])
            data, label = next(loader[1])
        data = data.to(self.device, non_blocking=True)
        if not self.multi_input:
            for task in self.task_name:
                label[task] = label[task].to(self.device, non_blocking=True)
        else:
            label = label.to(self.device, non_blocking=True)
        return data, label
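    # A minimal sketch of the batch structure ``_process_data`` expects when
    # ``multi_input`` is ``False`` (the dataset below is hypothetical, not part
    # of LibMTL): each batch is ``(data, label_dict)``, where ``label_dict``
    # maps task names to label tensors.
    #
    #     from torch.utils.data import Dataset, DataLoader
    #
    #     class ToyDataset(Dataset):
    #         def __len__(self):
    #             return 8
    #         def __getitem__(self, idx):
    #             x = torch.randn(3, 32, 32)
    #             return x, {'A': torch.tensor(0)}   # one label per task
    #
    #     train_loader = DataLoader(ToyDataset(), batch_size=4)
    #
    # With ``multi_input=True``, each task instead has its own dataloader that
    # returns plain ``(data, label)`` tensor pairs.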
    def process_preds(self, preds, task_name=None):
        r'''The processing of predictions for each task.

        - The default is no processing. If necessary, you can rewrite this function.
        - If ``multi_input`` is ``True``, ``task_name`` is valid and ``preds`` with type :class:`torch.Tensor` is the prediction of this task.
        - Otherwise, ``task_name`` is invalid and ``preds`` is a :class:`dict` of name-prediction pairs of all tasks.

        Args:
            preds (dict or torch.Tensor): The prediction of ``task_name`` or all tasks.
            task_name (str): The string of task name.
        '''
        return preds
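    # ``process_preds`` is an overridable hook; a minimal sketch of rewriting it
    # in a subclass, e.g. to upsample dense-prediction logits (``SegTrainer``
    # and the target size are illustrative, not part of LibMTL):
    #
    #     class SegTrainer(Trainer):
    #         def process_preds(self, preds, task_name=None):
    #             # single-input case: preds is a dict of name-prediction pairs
    #             for task in self.task_name:
    #                 preds[task] = F.interpolate(preds[task], size=(288, 384),
    #                                             mode='bilinear', align_corners=True)
    #             return preds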
    def _compute_loss(self, preds, gts, task_name=None):
        if not self.multi_input:
            train_losses = torch.zeros(self.task_num).to(self.device)
            for tn, task in enumerate(self.task_name):
                train_losses[tn] = self.meter.losses[task]._update_loss(preds[task], gts[task])
        else:
            train_losses = self.meter.losses[task_name]._update_loss(preds, gts)
        return train_losses

    def _prepare_dataloaders(self, dataloaders):
        if not self.multi_input:
            loader = [dataloaders, iter(dataloaders)]
            return loader, len(dataloaders)
        else:
            loader = {}
            batch_num = []
            for task in self.task_name:
                loader[task] = [dataloaders[task], iter(dataloaders[task])]
                batch_num.append(len(dataloaders[task]))
            return loader, batch_num
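    # For ``multi_input=True``, ``_prepare_dataloaders`` expects a dictionary of
    # name-dataloader pairs; a minimal sketch (``dataset_A``/``dataset_B`` are
    # hypothetical datasets):
    #
    #     train_dataloaders = {'A': DataLoader(dataset_A, batch_size=8),
    #                          'B': DataLoader(dataset_B, batch_size=8)}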
    def train(self, train_dataloaders, test_dataloaders, epochs,
              val_dataloaders=None, return_weight=False):
        r'''The training process of multi-task learning.

        Args:
            train_dataloaders (dict or torch.utils.data.DataLoader): The dataloaders used for training. \
                If ``multi_input`` is ``True``, it is a dictionary of name-dataloader pairs. \
                Otherwise, it is a single dataloader which returns data and a dictionary \
                of name-label pairs in each iteration.
            test_dataloaders (dict or torch.utils.data.DataLoader): The dataloaders used for validation or testing. \
                The same structure as ``train_dataloaders``.
            epochs (int): The total number of training epochs.
            val_dataloaders (dict or torch.utils.data.DataLoader, default=None): The dataloaders used for validation. \
                The same structure as ``train_dataloaders``.
            return_weight (bool): If ``True``, the loss weights will be returned.
        '''
        train_loader, train_batch = self._prepare_dataloaders(train_dataloaders)
        train_batch = max(train_batch) if self.multi_input else train_batch

        self.batch_weight = np.zeros([self.task_num, epochs, train_batch])
        self.model.train_loss_buffer = np.zeros([self.task_num, epochs])
        self.model.epochs = epochs
        for epoch in range(epochs):
            self.model.epoch = epoch
            self.model.train()
            self.meter.record_time('begin')
            for batch_index in range(train_batch):
                if not self.multi_input:
                    train_inputs, train_gts = self._process_data(train_loader)
                    train_preds = self.model(train_inputs)
                    train_preds = self.process_preds(train_preds)
                    train_losses = self._compute_loss(train_preds, train_gts)
                    self.meter.update(train_preds, train_gts)
                else:
                    train_losses = torch.zeros(self.task_num).to(self.device)
                    for tn, task in enumerate(self.task_name):
                        train_input, train_gt = self._process_data(train_loader[task])
                        train_pred = self.model(train_input, task)
                        train_pred = train_pred[task]
                        train_pred = self.process_preds(train_pred, task)
                        train_losses[tn] = self._compute_loss(train_pred, train_gt, task)
                        self.meter.update(train_pred, train_gt, task)

                self.optimizer.zero_grad()
                w = self.model.backward(train_losses, **self.kwargs['weight_args'])
                if w is not None:
                    self.batch_weight[:, epoch, batch_index] = w
                self.optimizer.step()

            self.meter.record_time('end')
            self.meter.get_score()
            self.model.train_loss_buffer[:, epoch] = self.meter.loss_item
            self.meter.display(epoch=epoch, mode='train')
            self.meter.reinit()

            if val_dataloaders is not None:
                self.meter.has_val = True
                val_improvement = self.test(val_dataloaders, epoch, mode='val', return_improvement=True)
            self.test(test_dataloaders, epoch, mode='test')
            if self.scheduler is not None:
                if self.scheduler_param['scheduler'] == 'reduce' and val_dataloaders is not None:
                    self.scheduler.step(val_improvement)
                else:
                    self.scheduler.step()
            if self.save_path is not None and self.meter.best_result['epoch'] == epoch:
                torch.save(self.model.state_dict(), os.path.join(self.save_path, 'best.pt'))
                print('Save Model {} to {}'.format(epoch, os.path.join(self.save_path, 'best.pt')))
        self.meter.display_best_result()
        if return_weight:
            return self.batch_weight
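    # A minimal sketch of launching training with the ``trainer`` built in the
    # class-level example (``train_loader``/``test_loader`` are assumed to be
    # dataloaders structured as described above):
    #
    #     trainer.train(train_dataloaders=train_loader,
    #                   test_dataloaders=test_loader,
    #                   epochs=100)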
    def test(self, test_dataloaders, epoch=None, mode='test', return_improvement=False):
        r'''The test process of multi-task learning.

        Args:
            test_dataloaders (dict or torch.utils.data.DataLoader): If ``multi_input`` is ``True``, \
                it is a dictionary of name-dataloader pairs. Otherwise, it is a single \
                dataloader which returns data and a dictionary of name-label pairs in each iteration.
            epoch (int, default=None): The current epoch.
            mode (str, default='test'): The display mode passed to the performance meter, e.g. ``'test'`` or ``'val'``.
            return_improvement (bool, default=False): If ``True``, the improvement recorded by the performance meter is returned.
        '''
        test_loader, test_batch = self._prepare_dataloaders(test_dataloaders)

        self.model.eval()
        self.meter.record_time('begin')
        with torch.no_grad():
            if not self.multi_input:
                for batch_index in range(test_batch):
                    test_inputs, test_gts = self._process_data(test_loader)
                    test_preds = self.model(test_inputs)
                    test_preds = self.process_preds(test_preds)
                    test_losses = self._compute_loss(test_preds, test_gts)
                    self.meter.update(test_preds, test_gts)
            else:
                for tn, task in enumerate(self.task_name):
                    for batch_index in range(test_batch[tn]):
                        test_input, test_gt = self._process_data(test_loader[task])
                        test_pred = self.model(test_input, task)
                        test_pred = test_pred[task]
                        test_pred = self.process_preds(test_pred)
                        test_loss = self._compute_loss(test_pred, test_gt, task)
                        self.meter.update(test_pred, test_gt, task)
        self.meter.record_time('end')
        self.meter.get_score()
        self.meter.display(epoch=epoch, mode=mode)
        improvement = self.meter.improvement
        self.meter.reinit()
        if return_improvement:
            return improvement
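# A minimal evaluation sketch: constructing a ``Trainer`` with ``load_path``
# restores a saved checkpoint (``'./ckpt'`` is an illustrative path; a
# directory resolves to ``./ckpt/best.pt``), after which ``test`` runs one
# evaluation pass over the test dataloaders.
#
#     trainer = Trainer(task_dict=task_dict, weighting=EW, architecture=HPS,
#                       encoder_class=ResNet18, decoders=decoders,
#                       rep_grad=False, multi_input=False,
#                       optim_param=optim_param, scheduler_param=scheduler_param,
#                       load_path='./ckpt', **kwargs)
#     trainer.test(test_loader)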