LibMTL.trainer
- class Trainer(task_dict, weighting, architecture, encoder_class, decoders, rep_grad, multi_input, optim_param, scheduler_param, save_path=None, load_path=None, **kwargs)
Bases: torch.nn.Module
A Multi-Task Learning Trainer.
This is a unified and extensible training framework for multi-task learning.
- Parameters:
task_dict (dict) – A dictionary of name-information pairs of type (str, dict). The sub-dictionary for each task has four entries, keyed metrics, metrics_fn, loss_fn, and weight, each of which corresponds to a list. The list of metrics has m strings, representing the names of the m metrics for this task. The list of metrics_fn has two elements, the updating function and the score function, which define how to update those objectives during training and how to obtain the final scores, respectively. The list of loss_fn has m loss functions, one for each metric. The list of weight has m binary integers, one for each metric, where 1 means that a higher score indicates better performance and 0 means the opposite.
weighting (class) – A weighting strategy class based on LibMTL.weighting.abstract_weighting.AbsWeighting.
architecture (class) – An architecture class based on LibMTL.architecture.abstract_arch.AbsArchitecture.
encoder_class (class) – A neural network class.
decoders (dict) – A dictionary of name-decoder pairs of type (str, torch.nn.Module).
rep_grad (bool) – If True, the gradient of the representation for each task can be computed.
multi_input (bool) – True if each task has its own input data, otherwise False.
optim_param (dict) – A dictionary of configurations for the optimizer.
scheduler_param (dict) – A dictionary of configurations for the learning rate scheduler. Set it to None if you do not use a learning rate scheduler.
kwargs (dict) – A dictionary of hyperparameters of the weighting and architecture methods.
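The weight flags in task_dict can be read as a per-metric direction of improvement. The following is a minimal sketch of that convention only, not LibMTL code: the helper name is_better and the sample scores are hypothetical.

```python
def is_better(new_scores, old_scores, weight):
    """Return one bool per metric: True where the new score improves on the old.

    weight[i] == 1 means higher is better for metric i; 0 means lower is better.
    Hypothetical helper for illustration; not part of LibMTL.
    """
    if not (len(new_scores) == len(old_scores) == len(weight)):
        raise ValueError("new_scores, old_scores and weight must have equal length")
    result = []
    for new, old, w in zip(new_scores, old_scores, weight):
        if w not in (0, 1):
            raise ValueError("weight entries must be 0 or 1")
        # Flip the comparison for metrics where a lower score is better.
        result.append(new > old if w == 1 else new < old)
    return result


# Example: accuracy (higher is better) and absolute error (lower is better).
flags = is_better(new_scores=[0.91, 0.40], old_scores=[0.88, 0.55], weight=[1, 0])
```

Here both entries of flags are True, because the accuracy rose and the error fell.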
Note
It is recommended to use LibMTL.config.prepare_args() to return the dictionaries of optim_param, scheduler_param, and kwargs.
Examples:
    import torch.nn as nn
    from LibMTL import Trainer
    from LibMTL.loss import CE_loss_fn
    from LibMTL.metrics import acc_update_fun, acc_score_fun
    from LibMTL.weighting import EW
    from LibMTL.architecture import HPS
    from LibMTL.model import ResNet18
    from LibMTL.config import prepare_args

    task_dict = {'A': {'metrics': ['Acc'],
                       'metrics_fn': [acc_update_fun, acc_score_fun],
                       'loss_fn': [CE_loss_fn],
                       'weight': [1]}}

    decoders = {'A': nn.Linear(512, 31)}

    # You can use command-line arguments and return configurations by ``prepare_args``.
    # kwargs, optim_param, scheduler_param = prepare_args(params)
    optim_param = {'optim': 'adam', 'lr': 1e-3, 'weight_decay': 1e-4}
    scheduler_param = {'scheduler': 'step'}
    kwargs = {'weight_args': {}, 'arch_args': {}}

    trainer = Trainer(task_dict=task_dict,
                      weighting=EW,
                      architecture=HPS,
                      encoder_class=ResNet18,
                      decoders=decoders,
                      rep_grad=False,
                      multi_input=False,
                      optim_param=optim_param,
                      scheduler_param=scheduler_param,
                      **kwargs)
- process_preds(preds, task_name=None)
The processing of predictions for each task.
The default is no processing. If necessary, you can override this method.
If multi_input is True, task_name is valid and preds, of type torch.Tensor, is the prediction of this task.
Otherwise, task_name is invalid and preds is a dict of name-prediction pairs of all tasks.
- Parameters:
preds (dict or torch.Tensor) – The prediction of task_name or of all tasks.
task_name (str) – The string of the task name.
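An override of process_preds has to handle both input shapes described above. The following sketch shows the branching only. TrainerStandIn is a hypothetical stand-in for the real Trainer, and plain lists of floats stand in for torch.Tensor predictions so that the example runs without PyTorch.

```python
class TrainerStandIn:
    """Hypothetical stand-in for the real Trainer; only the hook is modelled."""

    def __init__(self, multi_input):
        self.multi_input = multi_input

    def process_preds(self, preds, task_name=None):
        # Default behaviour described above: no processing.
        return preds


class RoundingTrainer(TrainerStandIn):
    """Overrides the hook to round every prediction to two decimals."""

    @staticmethod
    def _post(pred):
        # Plain lists of floats stand in for torch.Tensor predictions here.
        return [round(value, 2) for value in pred]

    def process_preds(self, preds, task_name=None):
        if self.multi_input:
            # preds is the prediction of the single task named by task_name.
            return self._post(preds)
        # preds maps every task name to its prediction; task_name is not used.
        return {name: self._post(pred) for name, pred in preds.items()}


single = RoundingTrainer(multi_input=True).process_preds([0.1234, 0.9876], task_name="A")
joint = RoundingTrainer(multi_input=False).process_preds({"A": [0.1234], "B": [0.5678]})
```

With the real class, the same pattern would subclass Trainer and replace the rounding step with whatever tensor operation the task needs.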
- train(train_dataloaders, test_dataloaders, epochs, val_dataloaders=None, return_weight=False)
The training process of multi-task learning.
- Parameters:
train_dataloaders (dict or torch.utils.data.DataLoader) – The dataloaders used for training. If multi_input is True, it is a dictionary of name-dataloader pairs. Otherwise, it is a single dataloader which returns data and a dictionary of name-label pairs in each iteration.
test_dataloaders (dict or torch.utils.data.DataLoader) – The dataloaders used for validation or testing. It has the same structure as train_dataloaders.
epochs (int) – The total number of training epochs.
return_weight (bool) – If True, the loss weights will be returned.
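The two dataloader layouts differ in where the task names live. The sketch below walks both layouts and yields the same kind of triple from each. The helper iter_task_batches is hypothetical, plain lists stand in for torch.utils.data.DataLoader objects, and the (data, label) batch shape for the per-task loaders is an assumption, since the description above does not specify it.

```python
def iter_task_batches(dataloaders, multi_input):
    """Yield (task_name, data, label) triples from either dataloader layout.

    Hypothetical helper for illustration; any iterable stands in for a
    torch.utils.data.DataLoader.
    """
    if multi_input:
        # One loader per task; each batch is assumed to be a (data, label) pair.
        for task_name, loader in dataloaders.items():
            for data, label in loader:
                yield task_name, data, label
    else:
        # One shared loader; each batch is (data, {task_name: label}).
        for data, labels in dataloaders:
            for task_name, label in labels.items():
                yield task_name, data, label


# multi_input=False: every task shares the same input batch.
shared_loader = [([1.0, 2.0], {"A": 0, "B": 1})]
# multi_input=True: each task has its own loader and its own input data.
per_task_loaders = {"A": [([1.0, 2.0], 0)], "B": [([3.0, 4.0], 1)]}

shared = list(iter_task_batches(shared_loader, multi_input=False))
separate = list(iter_task_batches(per_task_loaders, multi_input=True))
```

In the shared layout both tasks receive the same data, while in the per-task layout each task sees only its own batches.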
- test(test_dataloaders, epoch=None, mode='test', return_improvement=False)
The test process of multi-task learning.
- Parameters:
test_dataloaders (dict or torch.utils.data.DataLoader) – If multi_input is True, it is a dictionary of name-dataloader pairs. Otherwise, it is a single dataloader which returns data and a dictionary of name-label pairs in each iteration.
epoch (int, default=None) – The current epoch.