LibMTL

class Trainer(task_dict, weighting, architecture, encoder_class, decoders, rep_grad, multi_input, optim_param, scheduler_param, save_path=None, load_path=None, **kwargs)[source]

Bases: torch.nn.Module

A Multi-Task Learning Trainer.

This is a unified and extensible training framework for multi-task learning.

Parameters
  • task_dict (dict) – A dictionary of name-information pairs of type (str, dict). The sub-dictionary for each task has four entries with the keys metrics, metrics_fn, loss_fn, and weight, each of which maps to a list. The metrics list has m strings, representing the names of the m metrics for this task. The metrics_fn list has two elements, the updating and score functions, which define how to update those metrics during training and how to obtain the final scores, respectively. The loss_fn list has m loss functions, one for each metric. The weight list has m binary integers, one for each metric, where 1 means that a higher score indicates better performance and 0 means the opposite.

  • weighting (class) – A weighting strategy class based on LibMTL.weighting.abstract_weighting.AbsWeighting.

  • architecture (class) – An architecture class based on LibMTL.architecture.abstract_arch.AbsArchitecture.

  • encoder_class (class) – A neural network class.

  • decoders (dict) – A dictionary of name-decoder pairs of type (str, torch.nn.Module).

  • rep_grad (bool) – If True, the gradient of the representation for each task can be computed.

  • multi_input (bool) – True if each task has its own input data; otherwise False.

  • optim_param (dict) – A dictionary of configurations for the optimizer.

  • scheduler_param (dict) – A dictionary of configurations for the learning rate scheduler. Set it to None if you do not use a learning rate scheduler.

  • kwargs (dict) – A dictionary of hyperparameters of weighting and architecture methods.

Note

It is recommended to use LibMTL.config.prepare_args() to return the dictionaries of optim_param, scheduler_param, and kwargs.

Examples:

import torch.nn as nn
from LibMTL import Trainer
from LibMTL.loss import CE_loss_fn
from LibMTL.metrics import acc_update_fun, acc_score_fun
from LibMTL.weighting import EW
from LibMTL.architecture import HPS
from LibMTL.model import ResNet18
from LibMTL.config import prepare_args

task_dict = {'A': {'metrics': ['Acc'],
                   'metrics_fn': [acc_update_fun, acc_score_fun],
                   'loss_fn': [CE_loss_fn],
                   'weight': [1]}}

decoders = {'A': nn.Linear(512, 31)}

# You can use command-line arguments and return configurations by ``prepare_args``.
# kwargs, optim_param, scheduler_param = prepare_args(params)
optim_param = {'optim': 'adam', 'lr': 1e-3, 'weight_decay': 1e-4}
scheduler_param = {'scheduler': 'step'}
kwargs = {'weight_args': {}, 'arch_args': {}}

trainer = Trainer(task_dict=task_dict,
                  weighting=EW,
                  architecture=HPS,
                  encoder_class=ResNet18,
                  decoders=decoders,
                  rep_grad=False,
                  multi_input=False,
                  optim_param=optim_param,
                  scheduler_param=scheduler_param,
                  **kwargs)

process_preds(self, preds, task_name=None)

The processing of prediction for each task.

  • The default is no processing. If necessary, you can override this method.

  • If multi_input is True, task_name is valid and preds is a torch.Tensor holding the prediction of that task.

  • Otherwise, task_name is invalid and preds is a dict of name-prediction pairs for all tasks.

Parameters
  • preds (dict or torch.Tensor) – The prediction of task_name or all tasks.

  • task_name (str) – The name of the task.
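
The two cases above suggest how an override might branch on multi_input. The following is a minimal sketch under stated assumptions: StubTrainer is a hypothetical stand-in that mirrors only the process_preds hook (it is not the LibMTL Trainer), and plain Python lists stand in for torch.Tensor predictions so the example runs without any dependencies.

```python
class StubTrainer:
    """Hypothetical stand-in for Trainer; mirrors only the process_preds hook."""

    def __init__(self, multi_input):
        self.multi_input = multi_input

    def process_preds(self, preds, task_name=None):
        # Default behaviour described in the documentation: no processing.
        return preds


class ClampingTrainer(StubTrainer):
    """Example override that clamps every prediction into [0, 1]."""

    @staticmethod
    def _clamp(values):
        return [min(max(v, 0.0), 1.0) for v in values]

    def process_preds(self, preds, task_name=None):
        if self.multi_input:
            # preds is the prediction of the single task named task_name.
            return self._clamp(preds)
        # preds is a dict of name-prediction pairs for all tasks.
        return {name: self._clamp(p) for name, p in preds.items()}


single_task = ClampingTrainer(multi_input=True)
all_tasks = ClampingTrainer(multi_input=False)
clamped_one = single_task.process_preds([-0.5, 0.3, 1.7], task_name="A")
clamped_all = all_tasks.process_preds({"A": [2.0], "B": [-1.0, 0.5]})
```

With the real Trainer, the same branching would operate on tensors, for example by applying a tensor operation in place of the list comprehension.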

train(self, train_dataloaders, test_dataloaders, epochs, val_dataloaders=None, return_weight=False)

The training process of multi-task learning.

Parameters
  • train_dataloaders (dict or torch.utils.data.DataLoader) – The dataloaders used for training. If multi_input is True, it is a dictionary of name-dataloader pairs. Otherwise, it is a single dataloader which returns data and a dictionary of name-label pairs in each iteration.

  • test_dataloaders (dict or torch.utils.data.DataLoader) – The dataloaders used for validation or testing. It has the same structure as train_dataloaders.

  • epochs (int) – The total number of training epochs.

  • return_weight (bool) – If True, the loss weights will be returned.
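
The two dataloader layouts can be illustrated with plain Python iterables. This is a sketch of the data layout only, not of the Trainer's actual training loop: iterate_batches is a hypothetical helper, lists stand in for torch.utils.data.DataLoader objects, and the assumption that each per-task loader yields (data, label) pairs is not stated in the documentation above.

```python
def iterate_batches(loaders, task_names, multi_input):
    """Yield (task_name, data, label) triples from either loader layout.

    Hypothetical helper for illustration; not part of LibMTL.
    """
    if multi_input:
        # One loader per task. Assumed here: each yields (data, label) pairs.
        for name in task_names:
            for data, label in loaders[name]:
                yield name, data, label
    else:
        # One shared loader: each batch is data plus a name-label dictionary.
        for data, labels in loaders:
            for name in task_names:
                yield name, data, labels[name]


# multi_input=False: a single loader returning data and name-label pairs.
shared_loader = [("x0", {"A": 0, "B": 1}), ("x1", {"A": 1, "B": 0})]

# multi_input=True: a dictionary of name-dataloader pairs.
per_task_loaders = {"A": [("a0", 0)], "B": [("b0", 1)]}

shared_batches = list(iterate_batches(shared_loader, ["A", "B"], False))
per_task_batches = list(iterate_batches(per_task_loaders, ["A", "B"], True))
```

In the shared layout every task sees the same input batch, while in the per-task layout each task draws from its own data source.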

test(self, test_dataloaders, epoch=None, mode='test', return_improvement=False)

The test process of multi-task learning.

Parameters
  • test_dataloaders (dict or torch.utils.data.DataLoader) – If multi_input is True, it is a dictionary of name-dataloader pairs. Otherwise, it is a single dataloader which returns data and a dictionary of name-label pairs in each iteration.

  • epoch (int, default=None) – The current epoch.