LibMTL
- class Trainer(task_dict, weighting, architecture, encoder_class, decoders, rep_grad, multi_input, optim_param, scheduler_param, save_path=None, load_path=None, **kwargs)
Bases: torch.nn.Module
A Multi-Task Learning Trainer.
This is a unified and extensible training framework for multi-task learning.
- Parameters
task_dict (dict) – A dictionary of name-information pairs of type (str, dict). The sub-dictionary for each task has four entries, keyed by metrics, metrics_fn, loss_fn, and weight, and each entry is a list. The metrics list has m strings, representing the names of the m metrics for this task. The metrics_fn list has two elements, the updating function and the score function, which define how to update those objectives during training and how to obtain the final scores, respectively. The loss_fn list has m loss functions, one for each metric. The weight list has m binary integers, one for each metric, where 1 means that a higher score indicates better performance and 0 means the opposite.
weighting (class) – A weighting strategy class based on LibMTL.weighting.abstract_weighting.AbsWeighting.
architecture (class) – An architecture class based on LibMTL.architecture.abstract_arch.AbsArchitecture.
encoder_class (class) – A neural network class.
decoders (dict) – A dictionary of name-decoder pairs of type (str, torch.nn.Module).
rep_grad (bool) – If True, the gradient of the representation for each task can be computed.
multi_input (bool) – True if each task has its own input data, otherwise False.
optim_param (dict) – A dictionary of configurations for the optimizer.
scheduler_param (dict) – A dictionary of configurations for the learning rate scheduler. Set it to None if you do not use a learning rate scheduler.
kwargs (dict) – A dictionary of hyperparameters of weighting and architecture methods.
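The layout rules for task_dict described above can be checked mechanically. The following is a minimal sketch, not part of LibMTL: the helpers check_task_dict and is_better are hypothetical names introduced here for illustration, and they encode only the constraints stated in the parameter description (four list-valued entries, two metric functions, m loss functions and m binary weights).

```python
REQUIRED_KEYS = ("metrics", "metrics_fn", "loss_fn", "weight")


def check_task_dict(task_dict):
    """Hypothetical helper: raise ValueError if task_dict breaks the documented layout."""
    for name, info in task_dict.items():
        if not isinstance(name, str) or not isinstance(info, dict):
            raise ValueError(f"task {name!r}: expected a (str, dict) pair")
        # Every one of the four entries must be present and must be a list.
        for key in REQUIRED_KEYS:
            if not isinstance(info.get(key), list):
                raise ValueError(f"task {name!r}: entry {key!r} must be a list")
        m = len(info["metrics"])
        if not all(isinstance(metric, str) for metric in info["metrics"]):
            raise ValueError(f"task {name!r}: metric names must be strings")
        # metrics_fn holds exactly the updating function and the score function.
        if len(info["metrics_fn"]) != 2:
            raise ValueError(f"task {name!r}: metrics_fn needs exactly two elements")
        # loss_fn and weight each hold one element per metric.
        if len(info["loss_fn"]) != m or len(info["weight"]) != m:
            raise ValueError(f"task {name!r}: loss_fn and weight need {m} elements")
        if any(w not in (0, 1) for w in info["weight"]):
            raise ValueError(f"task {name!r}: weight entries must be 0 or 1")
    return True


def is_better(new_score, old_score, weight):
    """Compare two scores of one metric: weight 1 = higher is better, 0 = lower is better."""
    return new_score > old_score if weight == 1 else new_score < old_score
```

Calling check_task_dict before constructing the Trainer surfaces a malformed entry early, and is_better shows how the binary weight flag could be read when comparing two scores of the same metric.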
Note
It is recommended to use LibMTL.config.prepare_args() to return the dictionaries of optim_param, scheduler_param, and kwargs.
Examples:
import torch.nn as nn
from LibMTL import Trainer
from LibMTL.loss import CE_loss_fn
from LibMTL.metrics import acc_update_fun, acc_score_fun
from LibMTL.weighting import EW
from LibMTL.architecture import HPS
from LibMTL.model import ResNet18
from LibMTL.config import prepare_args

task_dict = {'A': {'metrics': ['Acc'],
                   'metrics_fn': [acc_update_fun, acc_score_fun],
                   'loss_fn': [CE_loss_fn],
                   'weight': [1]}}

decoders = {'A': nn.Linear(512, 31)}

# You can use command-line arguments and return configurations by ``prepare_args``.
# kwargs, optim_param, scheduler_param = prepare_args(params)
optim_param = {'optim': 'adam', 'lr': 1e-3, 'weight_decay': 1e-4}
scheduler_param = {'scheduler': 'step'}
kwargs = {'weight_args': {}, 'arch_args': {}}

trainer = Trainer(task_dict=task_dict,
                  weighting=EW,
                  architecture=HPS,
                  encoder_class=ResNet18,
                  decoders=decoders,
                  rep_grad=False,
                  multi_input=False,
                  optim_param=optim_param,
                  scheduler_param=scheduler_param,
                  **kwargs)
- process_preds(self, preds, task_name=None)
Processes the prediction of each task.
By default, no processing is applied. If necessary, you can override this method.
If multi_input is True, task_name is valid and preds, of type torch.Tensor, is the prediction of this task. Otherwise, task_name is invalid and preds is a dict of name-prediction pairs of all tasks.
- Parameters
preds (dict or torch.Tensor) – The prediction of task_name or of all tasks.
task_name (str) – The name of the task.
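The two calling conventions of process_preds can be sketched without LibMTL installed. Everything in the block below is an assumption made for illustration: TrainerStub is a stand-in for Trainer that keeps only the multi_input flag, ClampingTrainer and its _transform placeholder are hypothetical, plain lists stand in for torch.Tensor predictions, and the method is assumed to return the processed predictions in the same structure it received.

```python
class TrainerStub:
    """Stand-in for Trainer (assumption): keeps only what this sketch needs."""

    def __init__(self, multi_input):
        self.multi_input = multi_input

    def process_preds(self, preds, task_name=None):
        # Documented default: no processing.
        return preds


class ClampingTrainer(TrainerStub):
    """Hypothetical subclass that post-processes the prediction of every task."""

    @staticmethod
    def _transform(pred):
        # Placeholder post-processing: clamp negative values to zero.
        return [max(value, 0) for value in pred]

    def process_preds(self, preds, task_name=None):
        if self.multi_input:
            # preds is the prediction of the single task named task_name.
            return self._transform(preds)
        # preds maps every task name to its prediction; task_name is unused.
        return {name: self._transform(pred) for name, pred in preds.items()}


single_input = ClampingTrainer(multi_input=False)
processed_all = single_input.process_preds({"A": [-1, 2], "B": [3, -4]})

multi_input = ClampingTrainer(multi_input=True)
processed_one = multi_input.process_preds([-1, 2], task_name="A")
```

In a real subclass of Trainer, the same branch on multi_input would wrap tensor operations instead of the list placeholder.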
- train(self, train_dataloaders, test_dataloaders, epochs, val_dataloaders=None, return_weight=False)
The training process of multi-task learning.
- Parameters
train_dataloaders (dict or torch.utils.data.DataLoader) – The dataloaders used for training. If multi_input is True, it is a dictionary of name-dataloader pairs. Otherwise, it is a single dataloader which returns data and a dictionary of name-label pairs in each iteration.
test_dataloaders (dict or torch.utils.data.DataLoader) – The dataloaders used for validation or testing. It has the same structure as train_dataloaders.
epochs (int) – The total number of training epochs.
return_weight (bool) – If True, the loss weights will be returned.
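The difference between the two dataloader layouts can be made concrete with a small sketch. It is illustrative only and does not reproduce how Trainer consumes its dataloaders: iter_task_batches is a hypothetical helper, plain lists stand in for torch.utils.data.DataLoader objects, and in the multi-input case each per-task loader is assumed to yield (data, label) pairs.

```python
def iter_task_batches(dataloaders, multi_input):
    """Yield (task_name, data, label) triples from either documented layout."""
    if multi_input:
        # A dict of name-dataloader pairs; each loader is assumed to yield (data, label).
        for task_name, loader in dataloaders.items():
            for data, label in loader:
                yield task_name, data, label
    else:
        # A single loader yielding data plus a dict of name-label pairs.
        for data, labels in dataloaders:
            for task_name, label in labels.items():
                yield task_name, data, label


# multi_input=False: every task shares the same input batch.
shared_loader = [("x0", {"A": 0, "B": 1})]
shared_batches = list(iter_task_batches(shared_loader, multi_input=False))

# multi_input=True: every task has its own loader and its own inputs.
per_task_loaders = {"A": [("a0", 0)], "B": [("b0", 1)]}
per_task_batches = list(iter_task_batches(per_task_loaders, multi_input=True))
```

With a shared loader, the same input appears once per task; with per-task loaders, each task sees only its own inputs.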
- test(self, test_dataloaders, epoch=None, mode='test', return_improvement=False)
The test process of multi-task learning.
- Parameters
test_dataloaders (dict or torch.utils.data.DataLoader) – If multi_input is True, it is a dictionary of name-dataloader pairs. Otherwise, it is a single dataloader which returns data and a dictionary of name-label pairs in each iteration.
epoch (int, default=None) – The current epoch.