LibMTL: A PyTorch Library for Multi-Task Learning¶
Introduction¶
LibMTL is an open-source library built on PyTorch for Multi-Task Learning (MTL). This library has the following three characteristics.
Unified: LibMTL provides a unified code base and a consistent evaluation procedure, covering data processing, metric objectives, and hyper-parameters, on several representative MTL benchmark datasets. This allows quantitative, fair, and consistent comparisons between different MTL algorithms.
Comprehensive: LibMTL supports 84 MTL models, obtained by combining 7 architectures with 12 loss weighting strategies. LibMTL also provides a fair comparison on 3 computer vision datasets.
Extensible: LibMTL follows modular design principles, which allow users to flexibly and conveniently add customized components or make personalized modifications. Users can therefore quickly develop novel loss weighting strategies and architectures, or apply the existing MTL algorithms to new application scenarios, with the support of LibMTL.
Supported Algorithms¶
LibMTL currently supports the following algorithms:
12 loss weighting strategies.
Weighting Strategy | Venues | Comments |
---|---|---|
Equal Weighting (EW) | - | Implemented by us |
Gradient Normalization (GradNorm) | ICML 2018 | Implemented by us |
Uncertainty Weights (UW) | CVPR 2018 | Implemented by us |
MGDA | NeurIPS 2018 | Referenced from official PyTorch implementation |
Dynamic Weight Average (DWA) | CVPR 2019 | Referenced from official PyTorch implementation |
Geometric Loss Strategy (GLS) | CVPR 2019 workshop | Implemented by us |
Projecting Conflicting Gradient (PCGrad) | NeurIPS 2020 | Implemented by us |
Gradient sign Dropout (GradDrop) | NeurIPS 2020 | Implemented by us |
Impartial Multi-Task Learning (IMTL) | ICLR 2021 | Implemented by us |
Gradient Vaccine (GradVac) | ICLR 2021 Spotlight | Implemented by us |
Conflict-Averse Gradient descent (CAGrad) | NeurIPS 2021 | Referenced from official PyTorch implementation |
Random Loss Weighting (RLW) | arXiv | Implemented by us |
7 architectures.
Architecture | Venues | Comments |
---|---|---|
Hard Parameter Sharing (HPS) | ICML 1993 | Implemented by us |
Cross-stitch Networks (Cross_stitch) | CVPR 2016 | Implemented by us |
Multi-gate Mixture-of-Experts (MMoE) | KDD 2018 | Implemented by us |
Multi-Task Attention Network (MTAN) | CVPR 2019 | Referenced from official PyTorch implementation |
Customized Gate Control (CGC) | ACM RecSys 2020 Best Paper | Implemented by us |
Progressive Layered Extraction (PLE) | ACM RecSys 2020 Best Paper | Implemented by us |
DSelect-k | NeurIPS 2021 | Referenced from official TensorFlow implementation |
84 combinations of different architectures and loss weighting strategies.
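The figure of 84 is the product of the two tables above (7 architectures times 12 weighting strategies). As a small illustration, the pairs can be enumerated as follows. The strings below reuse the abbreviations from the tables and are only labels; they are not guaranteed to match the class names in LibMTL.architecture and LibMTL.weighting one-to-one.

```python
from itertools import product

# Abbreviations copied from the tables above; treat them as labels,
# not as guaranteed LibMTL class names.
weightings = ["EW", "GradNorm", "UW", "MGDA", "DWA", "GLS",
              "PCGrad", "GradDrop", "IMTL", "GradVac", "CAGrad", "RLW"]
architectures = ["HPS", "Cross_stitch", "MMoE", "MTAN", "CGC", "PLE", "DSelect-k"]

# Every (architecture, weighting) pair is one candidate MTL model.
combinations = list(product(architectures, weightings))
print(len(combinations))  # 7 * 12 = 84
```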
Citation¶
If you find LibMTL useful for your research or development, please cite the following:
@article{LibMTL,
title={LibMTL: A Python Library for Multi-Task Learning},
author={Baijiong Lin and Yu Zhang},
journal={arXiv preprint arXiv:2203.14338},
year={2022}
}
Contributors¶
LibMTL is developed and maintained by Baijiong Lin and Yu Zhang.
Contact Us¶
If you have any questions or suggestions, please feel free to contact us by raising an issue or sending an email to bj.lin.email@gmail.com.
Acknowledgements¶
We would like to thank the authors who released the following public repositories (listed alphabetically): CAGrad, dselect_k_moe, MultiObjectiveOptimization, and mtan.
Installation¶
Dependencies¶
To install LibMTL, you need to set up the following libraries:
Python >= 3.7
torch >= 1.8.0
torchvision >= 0.9.0
numpy >= 1.20
User Installation¶
Create a virtual environment
conda create -n libmtl python=3.8
conda activate libmtl
pip install torch==1.8.0 torchvision==0.9.0 numpy==1.20
Clone the repository
git clone https://github.com/median-research-group/LibMTL.git
Install LibMTL
cd LibMTL
pip install -e .
Quick Start¶
We use the NYUv2 dataset [1] as an example to show how to use LibMTL. More details and results are provided here.
Download Dataset¶
The NYUv2 dataset we used is pre-processed by mtan. You can download this dataset here. The directory structure is as follows:
*/nyuv2/
├── train
│   ├── depth
│   ├── image
│   ├── label
│   └── normal
└── val
    ├── depth
    ├── image
    ├── label
    └── normal
The NYUv2 dataset is an MTL benchmark dataset that includes three tasks: 13-class semantic segmentation, depth estimation, and surface normal prediction. The image directory contains the input images, and the label, depth, and normal directories contain the labels for the three tasks, respectively. We train the MTL model with the data in train and evaluate it on val.
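Before launching a run, it can help to confirm that the downloaded data matches the layout above. The following is a minimal sketch using only the standard library; missing_nyuv2_dirs is a hypothetical helper written for this page and is not part of LibMTL.

```python
from pathlib import Path

SPLITS = ("train", "val")
SUBDIRS = ("depth", "image", "label", "normal")

def missing_nyuv2_dirs(root):
    """Return the expected sub-directories that do not exist under root.

    Hypothetical helper (not part of LibMTL): it only checks the layout
    described above and does not inspect the files inside.
    """
    root = Path(root)
    return [str(Path(split) / sub)
            for split in SPLITS
            for sub in SUBDIRS
            if not (root / split / sub).is_dir()]
```

An empty list returned by missing_nyuv2_dirs("/path/to/nyuv2") means that all eight directories are present.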
Run a Model¶
The complete training code for the NYUv2 dataset is provided in examples/nyu. The file train_nyu.py is the main file for training on the NYUv2 dataset.
You can find the command-line arguments by running the following command.
python train_nyu.py -h
For instance, running the following command will train an MTL model with LibMTL.weighting.EW and LibMTL.architecture.HPS on the NYUv2 dataset.
python train_nyu.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step
If everything works fine, you will see the following output, which includes the training configuration and the number of model parameters.
========================================
General Configuration:
Wighting: EW
Architecture: HPS
Rep_Grad: False
Multi_Input: False
Seed: 0
Device: cuda:0
Optimizer Configuration:
optim: adam
lr: 0.0001
weight_decay: 1e-05
Scheduler Configuration:
scheduler: step
step_size: 100
gamma: 0.5
========================================
Total Params: 71888721
Trainable Params: 71888721
Non-trainable Params: 0
========================================
Next, the results will be printed in the following format.
LOG FORMAT | segmentation_LOSS mIoU pixAcc | depth_LOSS abs_err rel_err | normal_LOSS mean median <11.25 <22.5 <30 | TIME
Epoch: 0000 | TRAIN: 1.4417 0.2494 0.5717 | 1.4941 1.4941 0.5002 | 0.3383 43.1593 38.2601 0.0913 0.2639 0.3793 | Time: 81.6612 | TEST: 1.0898 0.3589 0.6676 | 0.7027 0.7027 0.2615 | 0.2143 32.8732 29.4323 0.1734 0.3878 0.5090 | Time: 11.9699
Epoch: 0001 | TRAIN: 0.8958 0.4194 0.7201 | 0.7011 0.7011 0.2448 | 0.1993 31.5235 27.8404 0.1826 0.4060 0.5361 | Time: 82.2399 | TEST: 0.9980 0.4189 0.6868 | 0.6274 0.6274 0.2347 | 0.1991 31.0144 26.5077 0.2065 0.4332 0.5551 | Time: 12.0278
When the training process ends, the best result on val will be printed as follows.
Best Result: Epoch 65, result {'segmentation': [0.5377492904663086, 0.7544658184051514], 'depth': [0.38453552363844823, 0.1605487049810748], 'normal': [23.573742, 17.04381, 0.35038458555943763, 0.609274380451927, 0.7207172795833373]}
References¶
- 1
Nathan Silberman, Derek Hoiem, Pushmeet Kohli, and Rob Fergus. Indoor segmentation and support inference from rgbd images. In Proceedings of the 8th European Conference on Computer Vision, 746–760. 2012.
What is Multi-Task Learning?¶
Multi-Task Learning (MTL) is an active research field in machine learning. It is a learning paradigm that aims to jointly learn several related tasks and to improve their generalization performance by leveraging common knowledge among them. In recent years, many researchers have successfully applied MTL to different fields such as computer vision, natural language processing, reinforcement learning, and recommendation systems.
Recent studies of MTL mainly focus on two perspectives: network architecture design and loss weighting. We implement some general and representative methods in LibMTL.
For a more detailed introduction, please refer to [1, 2, 3, 4].
Network Architecture¶
In the design of network architectures, the simplest and most popular method is hard parameter sharing (HPS, LibMTL.architecture.HPS), as shown in Fig. 1, where an encoder is shared among all the tasks and each task has its own specific decoder. Since most of the parameters are shared among tasks, such an architecture easily causes negative sharing when tasks are not related enough. To better deal with task relationships, different MTL architectures have been proposed. LibMTL supports several state-of-the-art architectures; please refer to LibMTL.architecture for details.
There are usually two types of MTL problems: the single-input problem and the multi-input problem. The single-input problem, as shown on the left of Fig. 1, means that each input has an output for every task or, equivalently, that all tasks share the input data. The NYUv2 dataset is an example of this problem. The multi-input problem, as shown on the right of Fig. 1, means that each task has its own input data. The Office-31 and Office-Home datasets belong to this type. LibMTL unifies these two cases in one training framework, and you just need to set the command-line argument multi_input correctly.

An illustration of the single-input problem (left) and the multi-input problem (right), using hard parameter sharing pattern as an example.¶
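The hard parameter sharing pattern and the two input settings can be sketched in plain PyTorch as follows. This is an illustrative toy module under our own assumptions (the class name TinyHPS, the layer sizes, and the task_name argument are hypothetical); it is not the code of LibMTL.architecture.HPS.

```python
import torch
import torch.nn as nn

class TinyHPS(nn.Module):
    """Toy hard parameter sharing: one shared encoder, one decoder per task."""

    def __init__(self, task_names, in_dim=16, hidden_dim=8, out_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        self.decoders = nn.ModuleDict(
            {task: nn.Linear(hidden_dim, out_dim) for task in task_names})

    def forward(self, inputs, task_name=None):
        rep = self.encoder(inputs)  # representation shared by all tasks
        if task_name is None:
            # Single-input: the same batch is decoded for every task.
            return {task: dec(rep) for task, dec in self.decoders.items()}
        # Multi-input: the batch belongs to one task only.
        return {task_name: self.decoders[task_name](rep)}

model = TinyHPS(["segmentation", "depth"])
single = model(torch.randn(4, 16))                    # outputs for both tasks
multi = model(torch.randn(4, 16), task_name="depth")  # output for one task
```

In the single-input call one batch yields a prediction per task, while in the multi-input call each task would be fed from its own dataloader.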
Weighting Strategy¶
Balancing the multiple losses corresponding to multiple tasks is another way to deal with task relationships, since the shared parameters are updated by all the task losses. Thus, different methods have been proposed to balance losses or gradients. LibMTL supports several state-of-the-art weighting strategies; please see LibMTL.weighting for details.
Some gradient balancing methods such as MGDA (LibMTL.weighting.MGDA) need to compute the gradient for each task first and then calculate the aggregated gradient in various ways. To reduce the computational cost, such methods can use the gradients of the representations after the encoder (abbreviated as rep-grad) to approximate the gradients of the shared parameters (abbreviated as param-grad).
The PyTorch implementation of rep-grad is shown in Fig. 2. We need to separate the computational graph into two parts by the detach operation. LibMTL unifies the two cases in one training framework, and you just need to set the command-line argument rep_grad correctly. Besides, the argument rep_grad does not conflict with multi_input.

An illustration of how to compute the gradient for representation.¶
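The detach-based split can be sketched in a few lines of plain PyTorch. This is a minimal illustration under our own assumptions: the encoder and decoders are toy linear layers, the per-task gradients are aggregated by a plain sum (standing in for whatever a weighting strategy would compute), and the decoders' own parameter gradients are left out. It does not reproduce LibMTL's internal code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(4, 3)                                # toy shared encoder
decoders = {"a": nn.Linear(3, 1), "b": nn.Linear(3, 1)}  # toy task decoders
x = torch.randn(8, 4)

rep = encoder(x)                                    # still attached to the encoder graph
rep_detached = rep.detach().clone().requires_grad_(True)  # second graph starts here

# Task losses are computed from the detached copy only.
losses = {t: dec(rep_detached).pow(2).mean() for t, dec in decoders.items()}

# Per-task gradients w.r.t. the representation (rep-grad); no encoder pass needed.
rep_grads = {}
for task, loss in losses.items():
    (grad,) = torch.autograd.grad(loss, rep_detached, retain_graph=True)
    rep_grads[task] = grad

# Aggregate the rep-grads (plain sum here) and back-propagate through the
# encoder exactly once.
aggregated = rep_grads["a"] + rep_grads["b"]
rep.backward(aggregated)
```

Each task gradient has the shape of the representation rather than of the encoder parameters, which is why rep-grad is cheaper when the encoder is large.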
References¶
- 1
Yu Zhang and Qiang Yang. A survey on multi-task learning. IEEE Transactions on Knowledge and Data Engineering, 2021.
- 2
Simon Vandenhende, Stamatios Georgoulis, Wouter Van Gansbeke, Marc Proesmans, Dengxin Dai, and Luc Van Gool. Multi-task learning for dense prediction tasks: a survey. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021.
- 3
Baijiong Lin, Feiyang Ye, and Yu Zhang. A closer look at loss weighting in multi-task learning. arXiv preprint arXiv:2111.10603, 2021.
- 4
Michael Crawshaw. Multi-task learning with deep neural networks: a survey. arXiv preprint arXiv:2009.09796, 2020.
Overall Framework¶
LibMTL provides a unified framework to train an MTL model with several architectures and weighting strategies on benchmark datasets. The overall framework consists of nine modules, as introduced below.
The Dataloader module is responsible for data pre-processing and loading.
The LibMTL.loss module defines loss functions for each task.
The LibMTL.metrics module defines evaluation metrics for all the tasks.
The LibMTL.config module is responsible for all the configuration parameters involved in the training process, such as the corresponding MTL setting (i.e. the multi-input case or not), the potential hyper-parameters of loss weighting strategies and architectures, the training configuration (e.g., the batch size, the running epoch, the random seed, and the learning rate), and so on. This module adopts command-line arguments to enable users to conveniently set those configuration parameters.
The LibMTL.Trainer module provides a unified framework for the training process under different MTL settings and for different MTL approaches.
The LibMTL.utils module implements some useful functionalities for the training process such as calculating the total number of parameters in an MTL model.
The LibMTL.architecture module contains the implementations of various architectures in MTL.
The LibMTL.weighting module contains the implementations of various loss weighting strategies in MTL.
The LibMTL.model module includes some popular backbone networks (e.g., ResNet).

Run a Benchmark¶
Here we introduce some MTL benchmark datasets and show how to run models on them for a fair comparison.
NYUv2¶
The NYUv2 dataset [1] is an indoor scene understanding dataset, which consists of video sequences recorded by the RGB and Depth cameras in the Microsoft Kinect. It contains 795 and 654 images with ground-truths for training and validation, respectively.
We use the pre-processed NYUv2 dataset in [2], which can be downloaded here. Each input image has been resized to 3x288x384 and has labels for three tasks: 13-class semantic segmentation, depth estimation, and surface normal prediction. Thus, it is a single-input problem, which means multi_input must be False.
The training code is mainly modified from mtan and is available in examples/nyu. We use the DeepLabV3+ architecture [3], where a ResNet-50 network with dilated convolutions [4], pretrained on the ImageNet dataset, is used as the shared encoder among tasks and the Atrous Spatial Pyramid Pooling (ASPP) module [3] is used as the task-specific head for each task.
Following [2], the evaluation metrics of the three tasks are as follows. Mean Intersection over Union (mIoU) and Pixel Accuracy (Pix Acc) are used for the semantic segmentation task. Absolute and relative errors (denoted by Abs Err and Rel Err) are used for the depth estimation task. Five metrics are used for the surface normal estimation task: the mean absolute error (Mean), the median absolute error (Median), and the percentage of pixels with an angular error below a threshold \(\epsilon\) with \(\epsilon \in \{11.25^{\circ}, 22.5^{\circ}, 30^{\circ}\}\) (abbreviated as <11.25, <22.5, <30), respectively. Among them, higher scores of mIoU, Pix Acc, <11.25, <22.5, and <30 mean better performance, and lower scores of Abs Err, Rel Err, Mean, and Median indicate better performance.
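To make the five surface normal metrics concrete, here is a small worked example on a handful of made-up angular errors (in degrees). The function normal_metrics is a hypothetical helper for illustration; the metric classes in examples/nyu may differ in details such as masking invalid pixels.

```python
from statistics import mean, median

def normal_metrics(angular_errors_deg, thresholds=(11.25, 22.5, 30.0)):
    """Summarize per-pixel angular errors (degrees) as
    [Mean, Median, <11.25, <22.5, <30]."""
    n = len(angular_errors_deg)
    within = [sum(e < t for e in angular_errors_deg) / n for t in thresholds]
    return [mean(angular_errors_deg), median(angular_errors_deg), *within]

errors = [5.0, 10.0, 20.0, 25.0, 40.0]  # made-up values for illustration
print(normal_metrics(errors))  # [20.0, 20.0, 0.4, 0.6, 0.8]
```

The first two entries should be as low as possible, while the last three (fractions of pixels below each threshold) should be as high as possible.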
Run a Model¶
The script train_nyu.py is the main file for training and evaluating an MTL model on the NYUv2 dataset. A set of command-line arguments is provided to allow users to adjust the training configuration.
Some important arguments are described as follows.
weighting: The weighting strategy. Refer to here.
arch: The MTL architecture. Refer to here.
gpu_id: The id of the GPU. The default value is ‘0’.
seed: The random seed for reproducibility. The default value is 0.
scheduler: The type of the learning rate scheduler. We recommend using ‘step’ here.
optim: The type of the optimizer. We recommend using ‘adam’ here.
dataset_path: The path of the NYUv2 dataset.
aug: If True, the model is trained with data augmentation.
train_bs: The batch size of training data. The default value is 8.
test_bs: The batch size of test data. The default value is 8.
The complete command-line arguments and their descriptions can be found by running the following command.
python train_nyu.py -h
If you understand those command-line arguments, you can train an MTL model by executing the following command.
python train_nyu.py --weighting WEIGHTING --arch ARCH --dataset_path PATH/nyuv2 --gpu_id GPU_ID --scheduler step
References¶
- 1
Nathan Silberman, Derek Hoiem, Pushmeet Kohli, and Rob Fergus. Indoor segmentation and support inference from rgbd images. In Proceedings of the 8th European Conference on Computer Vision, 746–760. 2012.
- 2
Shikun Liu, Edward Johns, and Andrew J. Davison. End-to-end multi-task learning with attention. In Proceedings of IEEE Conference on Computer Vision and Pattern Recognition, 1871–1880. 2019.
- 3
Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, and Hartwig Adam. Encoder-decoder with atrous separable convolution for semantic image segmentation. In Proceedings of the 14th European Conference on Computer Vision, volume 11211, 833–851. 2018.
- 4
Fisher Yu, Vladlen Koltun, and Thomas A. Funkhouser. Dilated residual networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 636–644. 2017.
Office-31 and Office-Home¶
The Office-31 dataset [1] consists of three classification tasks on three domains: Amazon, DSLR, and Webcam, where each task has 31 object categories. It can be downloaded here. This dataset contains 4,110 labeled images, and we randomly split these samples into 60% for training, 20% for validation, and the remaining 20% for testing.
The Office-Home dataset [2] has four classification tasks on four domains: Artistic images (abbreviated as Art), Clip art, Product images, and Real-world images. It can be downloaded here. This dataset has 15,500 labeled images in total, and each domain contains 65 classes. We divide the entire data in the same proportions as the Office-31 dataset.
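A random 60/20/20 split of this kind can be sketched as follows. This is only an illustration with the standard library: split_indices is a hypothetical helper, the seed is arbitrary, and whether the split is drawn per domain or over the whole dataset is not specified above, so the sketch simply splits a flat index range.

```python
import random

def split_indices(num_samples, train_frac=0.6, val_frac=0.2, seed=0):
    """Shuffle sample indices and cut them into train/val/test parts."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_train = round(num_samples * train_frac)
    n_val = round(num_samples * val_frac)
    return (indices[:n_train],
            indices[n_train:n_train + n_val],
            indices[n_train + n_val:])  # the rest is the test part

train_idx, val_idx, test_idx = split_indices(4110)
print(len(train_idx), len(val_idx), len(test_idx))  # 2466 822 822
```

With the 15,500 images quoted for Office-Home, the same call yields 9300, 3100, and 3100 indices.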
Both datasets belong to the multi-input setting in MTL. Thus, multi_input must be True for both Office datasets.
The training code is available in examples/office. We use a ResNet-18 network pretrained on the ImageNet dataset, followed by a fully connected layer, as the shared encoder among tasks, and a fully connected layer is applied as the task-specific output layer for each task. All the input images are resized to 3x224x224.
Run a Model¶
The script train_office.py is the main file for training and evaluating an MTL model on the Office-31 or Office-Home dataset. A set of command-line arguments is provided to allow users to adjust the training configuration.
Some important arguments are described as follows.
weighting: The weighting strategy. Refer to here.
arch: The MTL architecture. Refer to here.
gpu_id: The id of the GPU. The default value is ‘0’.
seed: The random seed for reproducibility. The default value is 0.
optim: The type of the optimizer. We recommend using ‘adam’ here.
dataset: Training on Office-31 or Office-Home. Options: ‘office-31’, ‘office-home’.
dataset_path: The path of the Office-31 or Office-Home dataset.
bs: The batch size of training, validation, and test data. The default value is 64.
The complete command-line arguments and their descriptions can be found by running the following command.
python train_office.py -h
If you understand those command-line arguments, you can train an MTL model by running a command like this.
python train_office.py --weighting WEIGHTING --arch ARCH --dataset_path PATH --gpu_id GPU_ID --multi_input
References¶
- 1
Kate Saenko, Brian Kulis, Mario Fritz, and Trevor Darrell. Adapting visual category models to new domains. In Proceedings of the 6th European Conference on Computer Vision, 213–226. 2010.
- 2
Hemanth Venkateswara, Jose Eusebio, Shayok Chakraborty, and Sethuraman Panchanathan. Deep hashing network for unsupervised domain adaptation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 5018–5027. 2017.
Apply to a New Dataset¶
Here we introduce how to apply LibMTL to a new dataset.
Define an MTL Problem¶
Firstly, you need to know the type of this MTL problem (i.e., a single-input problem or a multi-input problem, refer to here) and the information of each task, including the task’s name, evaluation metrics, loss function, and, for each metric, an indicator of whether a higher score means better performance.
multi_input is a command-line argument, and all tasks’ information needs to be defined as a dictionary. LibMTL provides some common loss functions and metrics; refer to LibMTL.loss and LibMTL.metrics, respectively. Some examples are listed as follows.
Example 1 (The Office-31 Dataset)¶
from LibMTL.loss import CELoss
from LibMTL.metrics import AccMetric

# define tasks
task_name = ['amazon', 'dslr', 'webcam']
task_dict = {task: {'metrics': ['Acc'],
                    'metrics_fn': AccMetric(),
                    'loss_fn': CELoss(),
                    'weight': [1]} for task in task_name}
Besides, LibMTL also supports customized losses and metrics. For example, to develop the metric class for the segmentation task on the NYUv2 dataset, we need to inherit LibMTL.metrics.AbsMetric and rewrite the corresponding methods: update_fun(), score_fun(), and reinit(). Please see LibMTL.metrics.AbsMetric for details. The loss class for segmentation is customized similarly. Please refer to LibMTL.loss.AbsLoss for details.
Example 2 (The NYUv2 Dataset)¶
import torch
from LibMTL.metrics import AbsMetric

# seg
class SegMetric(AbsMetric):
    def __init__(self):
        super(SegMetric, self).__init__()
        self.num_classes = 13
        # confusion matrix: rows are ground-truth classes, columns are predictions
        self.record = torch.zeros((self.num_classes, self.num_classes), dtype=torch.int64)

    def update_fun(self, pred, gt):
        self.record = self.record.to(pred.device)
        pred = pred.softmax(1).argmax(1).flatten()
        gt = gt.long().flatten()
        # keep only pixels with a valid label
        k = (gt >= 0) & (gt < self.num_classes)
        inds = self.num_classes * gt[k].to(torch.int64) + pred[k]
        self.record += torch.bincount(inds, minlength=self.num_classes**2).reshape(self.num_classes, self.num_classes)

    def score_fun(self):
        h = self.record.float()
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        acc = torch.diag(h).sum() / h.sum()
        return [torch.mean(iu).item(), acc.item()]

    def reinit(self):
        self.record = torch.zeros((self.num_classes, self.num_classes), dtype=torch.int64)
The customized loss and metric classes of the three tasks on the NYUv2 dataset are put in examples/nyu/utils.py. After that, the three-task MTL problem on the NYUv2 dataset is defined as follows.
from utils import *

# define tasks
task_dict = {'segmentation': {'metrics': ['mIoU', 'pixAcc'],
                              'metrics_fn': SegMetric(),
                              'loss_fn': SegLoss(),
                              'weight': [1, 1]},
             'depth': {'metrics': ['abs_err', 'rel_err'],
                       'metrics_fn': DepthMetric(),
                       'loss_fn': DepthLoss(),
                       'weight': [0, 0]},
             'normal': {'metrics': ['mean', 'median', '<11.25', '<22.5', '<30'],
                        'metrics_fn': NormalMetric(),
                        'loss_fn': NormalLoss(),
                        'weight': [0, 0, 1, 1, 1]}}
Prepare Dataloaders¶
Secondly, you need to prepare the dataloaders in the correct format. For a multi-input problem like the Office-31 dataset, each task has its own dataloader and all dataloaders are put in a dictionary with the task names as the corresponding keys.
Example 1 (The Office-31 Dataset)¶
train_dataloaders = {'amazon': amazon_dataloader,
                     'dslr': dslr_dataloader,
                     'webcam': webcam_dataloader}
For a single-input problem like the NYUv2 dataset, all tasks share a common dataloader, which outputs a list in every iteration. The first element of this list is the input data tensor and the second is a dictionary of the label tensors with the task names as the corresponding keys. An example is shown as follows.
Example 2 (The NYUv2 Dataset)¶
nyuv2_train_loader = xx
# print(next(iter(nyuv2_train_loader)))
# [torch.Tensor, {'segmentation': torch.Tensor,
#                 'depth': torch.Tensor,
#                 'normal': torch.Tensor}]
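The single-input format can be reproduced with a toy dataset whose items are an (input, label dictionary) pair; PyTorch's default collation then batches the input and each label separately. The class ToySingleInputDataset and the tensor sizes below are placeholders chosen for illustration, not the NYUv2 loader from examples/nyu.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToySingleInputDataset(Dataset):
    """Placeholder dataset: one input and one label per task for each sample."""

    def __init__(self, num_samples=6):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = torch.zeros(3, 8, 8)  # dummy input
        labels = {'segmentation': torch.zeros(8, 8, dtype=torch.long),
                  'depth': torch.zeros(1, 8, 8),
                  'normal': torch.zeros(3, 8, 8)}
        return image, labels

loader = DataLoader(ToySingleInputDataset(), batch_size=2)
inputs, labels = next(iter(loader))
print(inputs.shape, {task: tuple(t.shape) for task, t in labels.items()})
```

Every iteration yields the batched input tensor followed by a dictionary with one batched label tensor per task name.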
Define Encoder and Decoders¶
Thirdly, you need to define the shared encoder and task-specific decoders. LibMTL provides some neural networks such as ResNet-based networks. Please see LibMTL.model for details. You can also customize the encoder and decoders.
Note that the encoder should not be instantiated (pass the class, or a function that builds the encoder), while the decoders should be instantiated.
Example 1 (The Office-31 Dataset)¶
import torch
import torch.nn as nn
from LibMTL.model import resnet18

# define encoder and decoders
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        hidden_dim = 512
        self.resnet_network = resnet18(pretrained=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.hidden_layer_list = [nn.Linear(512, hidden_dim),
                                  nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(0.5)]
        self.hidden_layer = nn.Sequential(*self.hidden_layer_list)
        # initialization
        self.hidden_layer[0].weight.data.normal_(0, 0.005)
        self.hidden_layer[0].bias.data.fill_(0.1)

    def forward(self, inputs):
        out = self.resnet_network(inputs)
        out = torch.flatten(self.avgpool(out), 1)
        out = self.hidden_layer(out)
        return out

class_num = 31  # each Office-31 task has 31 object categories
# task_name is the list of task names used to build task_dict
decoders = nn.ModuleDict({task: nn.Linear(512, class_num) for task in task_name})
If the customized encoder is a ResNet-based network and you would like to use LibMTL.architecture.MTAN, please make sure that the encoder has an attribute named resnet_network corresponding to the ResNet network.
Example 2 (The NYUv2 Dataset)¶
import torch.nn as nn
from aspp import DeepLabHead
from LibMTL.model import resnet_dilated

# define encoder and decoders
def encoder_class():
    return resnet_dilated('resnet50')

num_out_channels = {'segmentation': 13, 'depth': 1, 'normal': 3}
# 2048: number of output channels of a ResNet-50 encoder (assumed here)
decoders = nn.ModuleDict({task: DeepLabHead(2048,
                                            num_out_channels[task]) for task in list(task_dict.keys())})
Instantiate the Training Framework¶
Fourthly, you need to instantiate the training framework. Please see LibMTL.Trainer for more details.
Example 1 (The Office-31 Dataset)¶
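import LibMTL.weighting as weighting_method
import LibMTL.architecture as architecture_method
from LibMTL import Trainer

# params holds the parsed command-line arguments; optim_param, scheduler_param,
# and kwargs are the dictionaries returned by LibMTL.config.prepare_args()
officeModel = Trainer(task_dict=task_dict,
                      weighting=weighting_method.__dict__[params.weighting],
                      architecture=architecture_method.__dict__[params.arch],
                      encoder_class=Encoder,
                      decoders=decoders,
                      rep_grad=params.rep_grad,
                      multi_input=params.multi_input,
                      optim_param=optim_param,
                      scheduler_param=scheduler_param,
                      **kwargs)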
Also, you can inherit the LibMTL.Trainer class and rewrite some functions like process_preds().
Example 2 (The NYUv2 Dataset)¶
import torch.nn.functional as F
import LibMTL.weighting as weighting_method
import LibMTL.architecture as architecture_method
from LibMTL import Trainer

class NYUtrainer(Trainer):
    def __init__(self, task_dict, weighting, architecture, encoder_class,
                 decoders, rep_grad, multi_input, optim_param, scheduler_param, **kwargs):
        super(NYUtrainer, self).__init__(task_dict=task_dict,
                                         weighting=weighting_method.__dict__[weighting],
                                         architecture=architecture_method.__dict__[architecture],
                                         encoder_class=encoder_class,
                                         decoders=decoders,
                                         rep_grad=rep_grad,
                                         multi_input=multi_input,
                                         optim_param=optim_param,
                                         scheduler_param=scheduler_param,
                                         **kwargs)

    def process_preds(self, preds):
        # upsample every task's prediction to the input resolution
        img_size = (288, 384)
        for task in self.task_name:
            preds[task] = F.interpolate(preds[task], img_size, mode='bilinear', align_corners=True)
        return preds

NYUmodel = NYUtrainer(task_dict=task_dict,
                      weighting=params.weighting,
                      architecture=params.arch,
                      encoder_class=encoder_class,
                      decoders=decoders,
                      rep_grad=params.rep_grad,
                      multi_input=params.multi_input,
                      optim_param=optim_param,
                      scheduler_param=scheduler_param,
                      **kwargs)
Run a Model¶
Finally, you can train the model by using the train() function like this.
officeModel.train(train_dataloaders=train_dataloaders,
                  val_dataloaders=val_dataloaders,
                  test_dataloaders=test_dataloaders,
                  epochs=100)
When the training process ends, the best results on the test dataset will be printed automatically. Please see LibMTL.Trainer.train() and LibMTL.utils.count_improvement() for details.
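Selecting a "best" epoch across several tasks requires folding all metrics into one number, and the higher-is-better flags in the weight entries of task_dict are what make this possible. The sketch below shows one plausible way to do it, as an average relative improvement over a baseline. It is a hedged illustration only: relative_improvement is a hypothetical function, and the exact formula used by LibMTL.utils.count_improvement() should be checked in its documentation.

```python
def relative_improvement(base, new, higher_is_better):
    """Average relative change of `new` over `base`, signed so that a
    positive value always means "better".

    All three arguments map task names to lists with one entry per metric;
    `higher_is_better` uses 1 for higher-is-better and 0 for the opposite.
    """
    per_task = []
    for task, base_scores in base.items():
        deltas = []
        for b, n, flag in zip(base_scores, new[task], higher_is_better[task]):
            sign = 1.0 if flag == 1 else -1.0
            deltas.append(sign * (n - b) / b)
        per_task.append(sum(deltas) / len(deltas))
    return sum(per_task) / len(per_task)

# Made-up scores: accuracy rises from 0.5 to 0.6, an error drops from 2.0 to 1.0.
base = {'cls': [0.5], 'reg': [2.0]}
new = {'cls': [0.6], 'reg': [1.0]}
flags = {'cls': [1], 'reg': [0]}
score = relative_improvement(base, new, flags)  # (0.2 + 0.5) / 2 = 0.35
```

Averaging within each task first keeps a task with many metrics from dominating the overall score.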
Customize an Architecture¶
Here we introduce how to customize a new architecture with the support of LibMTL.
Create a New Architecture Class¶
Firstly, you need to create a new architecture class by inheriting the class LibMTL.architecture.AbsArchitecture.
from LibMTL.architecture import AbsArchitecture

class NewArchitecture(AbsArchitecture):
    def __init__(self, task_name, encoder_class, decoders, rep_grad,
                 multi_input, device, **kwargs):
        super(NewArchitecture, self).__init__(task_name, encoder_class, decoders, rep_grad,
                                              multi_input, device, **kwargs)
Rewrite Relevant Methods¶
There are four important functions in LibMTL.architecture.AbsArchitecture.
forward(): The forward function. Its input/output format can be found in LibMTL.architecture.AbsArchitecture.forward(). To rewrite this function, you need to consider the cases of single-input and multi-input (refer to here) and the cases of rep-grad and param-grad (refer to here) if you want to combine your architecture with more weighting strategies or apply your architecture to more datasets.
get_share_params(): This function is used to return the shared parameters of the model. It returns all the parameters of the encoder by default. You can rewrite it if necessary.
zero_grad_share_params(): This function is used to set the gradients of the shared parameters to zero. It sets the gradients of all the encoder parameters to zero by default. You can rewrite it if necessary.
_prepare_rep(): This function is used to compute the gradients for representations. More details can be found here.
Customize a Weighting Strategy¶
Here we introduce how to customize a new weighting strategy with the support of LibMTL.
Create a New Weighting Class¶
Firstly, you need to create a new weighting class by inheriting the class LibMTL.weighting.AbsWeighting.
from LibMTL.weighting import AbsWeighting

class NewWeighting(AbsWeighting):
    def __init__(self):
        super(NewWeighting, self).__init__()
Rewrite Relevant Methods¶
There are four important functions in LibMTL.weighting.AbsWeighting.
backward(): The main function of a weighting strategy. Its input and output formats can be found in LibMTL.weighting.AbsWeighting.backward(). To rewrite this function, you need to consider the cases of single-input and multi-input (refer to here) and the cases of rep-grad and param-grad (refer to here) if you want to combine your weighting method with more architectures or apply your method to more datasets.
init_param(): This function is used to define and initialize some trainable parameters. It does nothing by default and can be rewritten if necessary.
_get_grads(): This function is used to return the gradients of representations or shared parameters (corresponding to the cases of rep-grad and param-grad, respectively).
_backward_new_grads(): This function is used to reset the gradients and make a backward pass (corresponding to the cases of rep-grad and param-grad, respectively).
The _get_grads() and _backward_new_grads() functions are very useful when rewriting the backward() function, and you can find more details here.
LibMTL¶
- class Trainer(task_dict, weighting, architecture, encoder_class, decoders, rep_grad, multi_input, optim_param, scheduler_param, save_path=None, load_path=None, **kwargs)¶
Bases: torch.nn.Module
A Multi-Task Learning Trainer.
This is a unified and extensible training framework for multi-task learning.
- Parameters
task_dict (dict) – A dictionary of name-information pairs of type (str, dict). The sub-dictionary for each task has four entries whose keywords are named metrics, metrics_fn, loss_fn, weight and each of them corresponds to a list. The list of metrics has m strings, representing the names of m metrics for this task. The list of metrics_fn has two elements, i.e., the updating and score functions, meaning how to update those objectives in the training process and obtain the final scores, respectively. The list of loss_fn has m loss functions corresponding to each metric. The list of weight has m binary integers corresponding to each metric, where 1 means the higher the score is, the better the performance, and 0 means the opposite.
weighting (class) – A weighting strategy class based on LibMTL.weighting.abstract_weighting.AbsWeighting.
architecture (class) – An architecture class based on LibMTL.architecture.abstract_arch.AbsArchitecture.
encoder_class (class) – A neural network class.
decoders (dict) – A dictionary of name-decoder pairs of type (str, torch.nn.Module).
rep_grad (bool) – If True, the gradient of the representation for each task can be computed.
multi_input (bool) – Is True if each task has its own input data, otherwise is False.
optim_param (dict) – A dictionary of configurations for the optimizer.
scheduler_param (dict) – A dictionary of configurations for the learning rate scheduler. Set it to None if you do not use a learning rate scheduler.
kwargs (dict) – A dictionary of hyperparameters of weighting and architecture methods.
Note
It is recommended to use LibMTL.config.prepare_args() to return the dictionaries of optim_param, scheduler_param, and kwargs.
Examples:
import torch.nn as nn
from LibMTL import Trainer
from LibMTL.loss import CE_loss_fn
from LibMTL.metrics import acc_update_fun, acc_score_fun
from LibMTL.weighting import EW
from LibMTL.architecture import HPS
from LibMTL.model import ResNet18
from LibMTL.config import prepare_args

task_dict = {'A': {'metrics': ['Acc'],
                   'metrics_fn': [acc_update_fun, acc_score_fun],
                   'loss_fn': [CE_loss_fn],
                   'weight': [1]}}

decoders = {'A': nn.Linear(512, 31)}

# You can use command-line arguments and return configurations by ``prepare_args``.
# kwargs, optim_param, scheduler_param = prepare_args(params)
optim_param = {'optim': 'adam', 'lr': 1e-3, 'weight_decay': 1e-4}
scheduler_param = {'scheduler': 'step'}
kwargs = {'weight_args': {}, 'arch_args': {}}

trainer = Trainer(task_dict=task_dict,
                  weighting=EW,
                  architecture=HPS,
                  encoder_class=ResNet18,
                  decoders=decoders,
                  rep_grad=False,
                  multi_input=False,
                  optim_param=optim_param,
                  scheduler_param=scheduler_param,
                  **kwargs)
- process_preds(self, preds, task_name=None)¶
Process the predictions of each task.
The default is no processing. If necessary, you can override this function.
If multi_input is True, task_name is valid and preds, of type torch.Tensor, is the prediction of this task. Otherwise, task_name is invalid and preds is a dict of name-prediction pairs of all tasks.
- Parameters
preds (dict or torch.Tensor) – The prediction of task_name or of all tasks.
task_name (str) – The name of the task.
- train(self, train_dataloaders, test_dataloaders, epochs, val_dataloaders=None, return_weight=False)¶
The training process of multi-task learning.
- Parameters
train_dataloaders (dict or torch.utils.data.DataLoader) – The dataloaders used for training. If multi_input is True, it is a dictionary of name-dataloader pairs. Otherwise, it is a single dataloader which returns data and a dictionary of name-label pairs in each iteration.
test_dataloaders (dict or torch.utils.data.DataLoader) – The dataloaders used for validation or testing. Same structure as train_dataloaders.
epochs (int) – The total number of training epochs.
return_weight (bool) – If True, the loss weights will be returned.
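As a rough illustration of the single-input case (multi_input is False), the sketch below builds a toy map-style dataset whose items are a shared input plus a dictionary of name-label pairs. The class name ToyMultiTaskDataset and the task names 'A' and 'B' are hypothetical; in practice such a dataset would be wrapped in torch.utils.data.DataLoader, which is left out here to keep the sketch dependency-free.

```python
class ToyMultiTaskDataset:
    """Hypothetical map-style dataset for the multi_input=False case."""

    def __init__(self, inputs, labels_a, labels_b):
        assert len(inputs) == len(labels_a) == len(labels_b)
        self.inputs = inputs
        self.labels_a = labels_a
        self.labels_b = labels_b

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # One shared input, plus a dict of name-label pairs (one label per task).
        return self.inputs[idx], {'A': self.labels_a[idx], 'B': self.labels_b[idx]}


dataset = ToyMultiTaskDataset(inputs=[[0.1, 0.2], [0.3, 0.4]],
                              labels_a=[0, 1],
                              labels_b=[2.5, 3.5])
sample, labels = dataset[1]
```

For the multi-input case, the analogous structure would be one dataset (and dataloader) per task, collected in a dictionary keyed by task name.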
- test(self, test_dataloaders, epoch=None, mode='test', return_improvement=False)¶
The test process of multi-task learning.
- Parameters
test_dataloaders (dict or torch.utils.data.DataLoader) – If multi_input is True, it is a dictionary of name-dataloader pairs. Otherwise, it is a single dataloader which returns data and a dictionary of name-label pairs in each iteration.
epoch (int, default=None) – The current epoch.
LibMTL.architecture¶
- class AbsArchitecture(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)[source]¶
Bases: torch.nn.Module
An abstract class for MTL architectures.
- Parameters
task_name (list) – A list of strings for all tasks.
encoder_class (class) – A neural network class.
decoders (dict) – A dictionary of name-decoder pairs of type (str, torch.nn.Module).
rep_grad (bool) – If True, the gradient of the representation for each task can be computed.
multi_input (bool) – Is True if each task has its own input data, otherwise is False.
device (torch.device) – The device where model and data will be allocated.
kwargs (dict) – A dictionary of hyperparameters of architectures.
- forward(self, inputs, task_name=None)¶
- Parameters
inputs (torch.Tensor) – The input data.
task_name (str, default=None) – The task name corresponding to inputs if multi_input is True.
- Returns
A dictionary of name-prediction pairs of type (str, torch.Tensor).
- Return type
dict
Return the shared parameters of the model.
Set gradients of the shared parameters to zero.
- class HPS(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)[source]¶
Bases: LibMTL.architecture.abstract_arch.AbsArchitecture
Hard Parameter Sharing (HPS).
This method is proposed in Multitask Learning: A Knowledge-Based Source of Inductive Bias (ICML 1993) and implemented by us.
- class Cross_stitch(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)¶
Bases: LibMTL.architecture.abstract_arch.AbsArchitecture
Cross-stitch Networks (Cross_stitch).
This method is proposed in Cross-stitch Networks for Multi-task Learning (CVPR 2016) and implemented by us.
Warning
Cross_stitch does not work with multi-input MTL problems, i.e., multi_input must be False.
Cross_stitch is only supported by ResNet-based encoders.
- class MMoE(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)¶
Bases: LibMTL.architecture.abstract_arch.AbsArchitecture
Multi-gate Mixture-of-Experts (MMoE).
This method is proposed in Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts (KDD 2018) and implemented by us.
- Parameters
img_size (list) – The size of input data. For example, [3, 224, 224] denotes input images with size 3x224x224.
num_experts (int) – The number of experts shared for all tasks. Each expert is an encoder network.
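To make the role of num_experts more concrete, here is a minimal, dependency-free sketch of multi-gate mixing in the spirit of the MMoE paper: each task has its own gate that produces softmax weights over the shared experts, and the task representation is the weighted sum of the expert outputs. The function names and the use of plain Python lists are assumptions made for illustration; this is not LibMTL's implementation.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mix_experts(expert_outputs, gate_logits):
    """Combine expert outputs with one task's gate.

    expert_outputs: list of num_experts feature vectors (equal length).
    gate_logits: list of num_experts raw gate scores for this task.
    """
    weights = softmax(gate_logits)
    dim = len(expert_outputs[0])
    return [sum(w * out[d] for w, out in zip(weights, expert_outputs))
            for d in range(dim)]


expert_outputs = [[1.0, 0.0], [0.0, 1.0]]  # num_experts = 2
rep_task_a = mix_experts(expert_outputs, gate_logits=[0.0, 0.0])   # equal gates
rep_task_b = mix_experts(expert_outputs, gate_logits=[50.0, 0.0])  # favours expert 0
```

In a real model the gate logits would come from a small learned layer applied to the input, so each task can weight the same experts differently.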
- forward(self, inputs, task_name=None)¶
- Parameters
inputs (torch.Tensor) – The input data.
task_name (str, default=None) – The task name corresponding to inputs if multi_input is True.
- Returns
A dictionary of name-prediction pairs of type (str, torch.Tensor).
- Return type
dict
Return the shared parameters of the model.
Set gradients of the shared parameters to zero.
- class MTAN(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)¶
Bases: LibMTL.architecture.abstract_arch.AbsArchitecture
Multi-Task Attention Network (MTAN).
This method is proposed in End-To-End Multi-Task Learning With Attention (CVPR 2019) and implemented by modifying from the official PyTorch implementation.
Warning
MTAN is only supported by ResNet-based encoders.
- forward(self, inputs, task_name=None)¶
- Parameters
inputs (torch.Tensor) – The input data.
task_name (str, default=None) – The task name corresponding to inputs if multi_input is True.
- Returns
A dictionary of name-prediction pairs of type (str, torch.Tensor).
- Return type
dict
Return the shared parameters of the model.
Set gradients of the shared parameters to zero.
- class CGC(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)[source]¶
Bases: LibMTL.architecture.MMoE.MMoE
Customized Gate Control (CGC).
This method is proposed in Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) and implemented by us.
- Parameters
img_size (list) – The size of input data. For example, [3, 224, 224] denotes input images with size 3x224x224.
num_experts (list) – The numbers of experts shared by all the tasks and specific to each task, respectively. Each expert is an encoder network.
- forward(self, inputs, task_name=None)¶
- class PLE(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)[source]¶
Bases: LibMTL.architecture.abstract_arch.AbsArchitecture
Progressive Layered Extraction (PLE).
This method is proposed in Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) and implemented by us.
- Parameters
img_size (list) – The size of input data. For example, [3, 224, 224] denotes input images with size 3x224x224.
num_experts (list) – The numbers of experts shared by all the tasks and specific to each task, respectively. Each expert is an encoder network.
Warning
- forward(self, inputs, task_name=None)¶
- Parameters
inputs (torch.Tensor) – The input data.
task_name (str, default=None) – The task name corresponding to inputs if multi_input is True.
- Returns
A dictionary of name-prediction pairs of type (str, torch.Tensor).
- Return type
dict
Return the shared parameters of the model.
Set gradients of the shared parameters to zero.
- class DSelect_k(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)[source]¶
Bases: LibMTL.architecture.MMoE.MMoE
DSelect-k.
This method is proposed in DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning (NeurIPS 2021) and implemented by modifying from the official TensorFlow implementation.
- Parameters
img_size (list) – The size of input data. For example, [3, 224, 224] denotes input images with size 3x224x224.
num_experts (int) – The number of experts shared by all the tasks. Each expert is an encoder network.
num_nonzeros (int) – The number of selected experts.
kgamma (float, default=1.0) – A scaling parameter for the smooth-step function.
- forward(self, inputs, task_name=None)¶
- class LTB(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)[source]¶
Bases: LibMTL.architecture.abstract_arch.AbsArchitecture
Learning To Branch (LTB).
This method is proposed in Learning to Branch for Multi-Task Learning (ICML 2020) and implemented by us.
Warning
- forward(self, inputs, task_name=None)¶
- Parameters
inputs (torch.Tensor) – The input data.
task_name (str, default=None) – The task name corresponding to inputs if multi_input is True.
- Returns
A dictionary of name-prediction pairs of type (str, torch.Tensor).
- Return type
dict
LibMTL.model¶
- resnet18(pretrained=False, progress=True, **kwargs)[source]¶
ResNet-18 model from “Deep Residual Learning for Image Recognition”
- Parameters
pretrained (bool) – If True, returns a model pre-trained on the ImageNet dataset.
progress (bool) – If True, displays a progress bar of the download to stderr.
- resnet34(pretrained=False, progress=True, **kwargs)[source]¶
ResNet-34 model from “Deep Residual Learning for Image Recognition”
- Parameters
pretrained (bool) – If True, returns a model pre-trained on the ImageNet dataset.
progress (bool) – If True, displays a progress bar of the download to stderr.
- resnet50(pretrained=False, progress=True, **kwargs)[source]¶
ResNet-50 model from “Deep Residual Learning for Image Recognition”
- Parameters
pretrained (bool) – If True, returns a model pre-trained on the ImageNet dataset.
progress (bool) – If True, displays a progress bar of the download to stderr.
- resnet101(pretrained=False, progress=True, **kwargs)[source]¶
ResNet-101 model from “Deep Residual Learning for Image Recognition”
- Parameters
pretrained (bool) – If True, returns a model pre-trained on the ImageNet dataset.
progress (bool) – If True, displays a progress bar of the download to stderr.
- resnet152(pretrained=False, progress=True, **kwargs)[source]¶
ResNet-152 model from “Deep Residual Learning for Image Recognition”
- Parameters
pretrained (bool) – If True, returns a model pre-trained on the ImageNet dataset.
progress (bool) – If True, displays a progress bar of the download to stderr.
- resnext50_32x4d(pretrained=False, progress=True, **kwargs)[source]¶
ResNeXt-50 32x4d model from “Aggregated Residual Transformation for Deep Neural Networks”
- Parameters
pretrained (bool) – If True, returns a model pre-trained on the ImageNet dataset.
progress (bool) – If True, displays a progress bar of the download to stderr.
- resnext101_32x8d(pretrained=False, progress=True, **kwargs)[source]¶
ResNeXt-101 32x8d model from “Aggregated Residual Transformation for Deep Neural Networks”
- Parameters
pretrained (bool) – If True, returns a model pre-trained on the ImageNet dataset.
progress (bool) – If True, displays a progress bar of the download to stderr.
- wide_resnet50_2(pretrained=False, progress=True, **kwargs)[source]¶
Wide ResNet-50-2 model from “Wide Residual Networks”
The model is the same as ResNet except for the number of bottleneck channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g., the last block in ResNet-50 has 2048-512-2048 channels, while in wide ResNet-50-2 there are 2048-1024-2048.
- Parameters
pretrained (bool) – If True, returns a model pre-trained on the ImageNet dataset.
progress (bool) – If True, displays a progress bar of the download to stderr.
- wide_resnet101_2(pretrained=False, progress=True, **kwargs)[source]¶
Wide ResNet-101-2 model from “Wide Residual Networks”
The model is the same as ResNet except for the number of bottleneck channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g., the last block in ResNet-101 has 2048-512-2048 channels, while in wide ResNet-101-2 there are 2048-1024-2048.
- Parameters
pretrained (bool) – If True, returns a model pre-trained on the ImageNet dataset.
progress (bool) – If True, displays a progress bar of the download to stderr.
- resnet_dilated(basenet, pretrained=True, dilate_scale=8)[source]¶
Dilated Residual Network models from “Dilated Residual Networks”
- Parameters
basenet (str) – The type of ResNet.
pretrained (bool) – If True, returns a model pre-trained on ImageNet.
dilate_scale ({8, 16}, default=8) – The type of dilating process.
LibMTL.weighting¶
- class AbsWeighting[source]¶
Bases: torch.nn.Module
An abstract class for weighting strategies.
- init_param(self)¶
Define and initialize some trainable parameters required by specific weighting methods.
- property backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
- class EW[source]¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Equal Weighting (EW).
The loss weight for each task is always 1 / T in every iteration, where T denotes the number of tasks.
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
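Following the description above, a minimal sketch of equal weighting on plain Python floats might look like this. The function name equal_weighting is hypothetical, and the sketch only computes the weighted objective; it does not perform the backward pass that the library's backward method is responsible for.

```python
def equal_weighting(losses):
    """Return the EW objective: each of the T task losses gets weight 1 / T."""
    num_tasks = len(losses)
    weights = [1.0 / num_tasks] * num_tasks
    total = sum(w * l for w, l in zip(weights, losses))
    return total, weights


total, weights = equal_weighting([0.9, 0.3, 0.6])
```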
- class GradNorm¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Gradient Normalization (GradNorm).
This method is proposed in GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks (ICML 2018) and implemented by us.
- Parameters
alpha (float, default=1.5) – The strength of the restoring force which pulls tasks back to a common training rate.
- init_param(self)¶
Define and initialize some trainable parameters required by specific weighting methods.
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
- class MGDA¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Multiple Gradient Descent Algorithm (MGDA).
This method is proposed in Multi-Task Learning as Multi-Objective Optimization (NeurIPS 2018) and implemented by modifying from the official PyTorch implementation.
- Parameters
mgda_gn ({'none', 'l2', 'loss', 'loss+'}, default='none') – The type of gradient normalization.
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
- class UW¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Uncertainty Weights (UW).
This method is proposed in Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (CVPR 2018) and implemented by us.
- init_param(self)¶
Define and initialize some trainable parameters required by specific weighting methods.
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
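As a hedged illustration of the idea behind uncertainty weighting, the sketch below evaluates the objective from the cited paper for regression-style losses: the sum over tasks of L_t / (2 * sigma_t^2) + log(sigma_t), parameterised by log_vars[t] = log(sigma_t^2). The function name and the plain-float interface are assumptions; in a trainable setting log_vars would play the role of the parameters that init_param defines, and the library's actual implementation may differ in detail.

```python
import math


def uncertainty_weighted_loss(losses, log_vars):
    """Sum of L_t / (2 * sigma_t^2) + log(sigma_t), with log_vars[t] = log(sigma_t^2)."""
    total = 0.0
    for loss, s in zip(losses, log_vars):
        # exp(-s) = 1 / sigma^2 and 0.5 * s = log(sigma)
        total += 0.5 * math.exp(-s) * loss + 0.5 * s
    return total


value_init = uncertainty_weighted_loss([2.0, 4.0], log_vars=[0.0, 0.0])
value_down = uncertainty_weighted_loss([2.0, 4.0], log_vars=[0.0, 1.0])
```

Raising the log-variance of a task shrinks the weight on its loss, while the 0.5 * s term keeps the variance from growing without bound.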
- class DWA[source]¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Dynamic Weight Average (DWA).
This method is proposed in End-To-End Multi-Task Learning With Attention (CVPR 2019) and implemented by modifying from the official PyTorch implementation.
- Parameters
T (float, default=2.0) – The softmax temperature.
- backward(self, losses, **kwargs)¶
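To show what the softmax temperature T controls, here is a small sketch of the weight rule from the cited paper: each task's weight is K times a softmax over r_t / T, where r_t is the ratio of that task's loss in the two preceding epochs and K is the number of tasks. The function and argument names are hypothetical, and this is a sketch of the published rule rather than the library's code.

```python
import math


def dwa_weights(prev_losses, prev_prev_losses, temperature=2.0):
    """Weights w_t = K * softmax(r_t / T), where r_t = L_t(e-1) / L_t(e-2)."""
    ratios = [a / b for a, b in zip(prev_losses, prev_prev_losses)]
    exps = [math.exp(r / temperature) for r in ratios]
    total = sum(exps)
    num_tasks = len(ratios)
    return [num_tasks * e / total for e in exps]


# Task 0 halved its loss; task 1 barely improved, so it receives the larger weight.
weights = dwa_weights(prev_losses=[0.5, 0.9], prev_prev_losses=[1.0, 1.0])
```

A larger temperature flattens the softmax, pushing the weights toward equal weighting.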
- class GLS¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Geometric Loss Strategy (GLS).
This method is proposed in MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning (CVPR 2019 workshop) and implemented by us.
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
- class GradDrop¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Gradient Sign Dropout (GradDrop).
This method is proposed in Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout (NeurIPS 2020) and implemented by us.
- Parameters
leak (float, default=0.0) – The leak parameter for the weighting matrix.
Warning
GradDrop does not support parameter gradients, i.e., rep_grad must be True.
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
- class PCGrad[source]¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Project Conflicting Gradients (PCGrad).
This method is proposed in Gradient Surgery for Multi-Task Learning (NeurIPS 2020) and implemented by us.
Warning
PCGrad does not support representation gradients, i.e., rep_grad must be False.
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
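The core projection step of gradient surgery can be sketched on plain Python lists as follows. When two task gradients have a negative inner product, the component of one along the other is removed before the gradients are summed. The function names are hypothetical, and the sketch visits the other tasks in a fixed order for reproducibility, whereas the cited paper samples that order at random; it is an illustration, not the library's implementation.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def pcgrad(grads):
    """Project each task gradient away from the gradients it conflicts with, then sum."""
    projected = [list(g) for g in grads]  # work on copies; originals stay untouched
    for i, g_i in enumerate(projected):
        for j, g_j in enumerate(grads):
            if i == j:
                continue
            inner = dot(g_i, g_j)
            if inner < 0:  # conflict: remove the component of g_i along g_j
                scale = inner / dot(g_j, g_j)
                for d in range(len(g_i)):
                    g_i[d] -= scale * g_j[d]
    dim = len(grads[0])
    update = [sum(g[d] for g in projected) for d in range(dim)]
    return update, projected


update, projected = pcgrad([[1.0, 0.0], [-1.0, 1.0]])
```

After projection, each modified gradient is orthogonal to the gradient it conflicted with, so the summed update no longer moves against either task.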
- class GradVac[source]¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Gradient Vaccine (GradVac).
This method is proposed in Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) and implemented by us.
- Parameters
GradVac_beta (float, default=0.5) – The exponential moving average (EMA) decay parameter.
GradVac_group_type (int, default=0) – The parameter granularity (0: whole_model; 1: all_layer; 2: all_matrix).
Warning
GradVac does not support representation gradients, i.e., rep_grad must be False.
- init_param(self)¶
- backward(self, losses, **kwargs)¶
- class IMTL¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Impartial Multi-task Learning (IMTL).
This method is proposed in Towards Impartial Multi-task Learning (ICLR 2021) and implemented by us.
- init_param(self)¶
Define and initialize some trainable parameters required by specific weighting methods.
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
- class CAGrad[source]¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Conflict-Averse Gradient descent (CAGrad).
This method is proposed in Conflict-Averse Gradient Descent for Multi-task learning (NeurIPS 2021) and implemented by modifying from the official PyTorch implementation.
- Parameters
calpha (float, default=0.5) – A hyperparameter that controls the convergence rate.
rescale ({0, 1, 2}, default=1) – The type of the gradient rescaling.
Warning
CAGrad does not support representation gradients, i.e., rep_grad must be False.
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
- class Nash_MTL¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Nash-MTL.
This method is proposed in Multi-Task Learning as a Bargaining Game (ICML 2022) and implemented by modifying from the official PyTorch implementation.
- Parameters
update_weights_every (int, default=1) – Period of weights update.
optim_niter (int, default=20) – The max iteration of optimization solver.
max_norm (float, default=1.0) – The max norm of the gradients.
Warning
Nash_MTL does not support representation gradients, i.e., rep_grad must be False.
- init_param(self)¶
Define and initialize some trainable parameters required by specific weighting methods.
- solve_optimization(self, gtg: numpy.array)¶
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
- class RLW¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Random Loss Weighting (RLW).
This method is proposed in Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning (TMLR 2022) and implemented by us.
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
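A minimal sketch of random loss weighting, assuming the common variant in which one standard-normal score is drawn per task in every iteration and normalised with a softmax. The function name, the explicit rng argument, and the example loss values are hypothetical choices made so the sketch is reproducible; other sampling distributions are possible.

```python
import math
import random


def random_loss_weights(num_tasks, rng):
    """Draw one standard-normal score per task and normalise with softmax."""
    scores = [rng.gauss(0.0, 1.0) for _ in range(num_tasks)]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
weights = random_loss_weights(3, rng)
weighted_loss = sum(w * l for w, l in zip(weights, [0.9, 0.3, 0.6]))
```

Because the weights are positive and sum to one, the weighted loss is always a convex combination of the task losses.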
- class MoCo¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
MoCo.
This method is proposed in Mitigating Gradient Bias in Multi-objective Learning: A Provably Convergent Approach (ICLR 2023) and implemented based on code shared by the author (Heshan Fernando: fernah@rpi.edu).
- Parameters
MoCo_beta (float, default=0.5) – The learning rate of y.
MoCo_beta_sigma (float, default=0.5) – The decay rate of MoCo_beta.
MoCo_gamma (float, default=0.1) – The learning rate of lambda.
MoCo_gamma_sigma (float, default=0.5) – The decay rate of MoCo_gamma.
MoCo_rho (float, default=0) – The \(\ell_2\) regularization parameter of lambda's update.
Warning
MoCo does not support representation gradients, i.e., rep_grad must be False.
- init_param(self)¶
Define and initialize some trainable parameters required by specific weighting methods.
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
- class Aligned_MTL[source]¶
Bases: LibMTL.weighting.abstract_weighting.AbsWeighting
Aligned-MTL.
This method is proposed in Independent Component Alignment for Multi-Task Learning (CVPR 2023) and implemented by modifying from the official PyTorch implementation.
- backward(self, losses, **kwargs)¶
- Parameters
losses (list) – A list of losses of each task.
kwargs (dict) – A dictionary of hyperparameters of weighting methods.
LibMTL.loss¶
LibMTL.utils¶
- set_random_seed(seed)[source]¶
Set the random seed for reproducibility.
- Parameters
seed (int, default=0) – The random seed.
- set_device(gpu_id)[source]¶
Set the device where model and data will be allocated.
- Parameters
gpu_id (str, default='0') – The id of gpu.
- count_parameters(model)[source]¶
Calculate the number of parameters for a model.
- Parameters
model (torch.nn.Module) – A neural network module.
- count_improvement(base_result, new_result, weight)[source]¶
Calculate the improvement between two results as
\[\Delta_{\mathrm{p}}=100\%\times \frac{1}{T}\sum_{t=1}^T \frac{1}{M_t}\sum_{m=1}^{M_t}\frac{(-1)^{w_{t,m}}(B_{t,m}-N_{t,m})}{N_{t,m}}.\]
- Parameters
base_result (dict) – A dictionary of scores of all metrics of all tasks.
new_result (dict) – The same structure as base_result.
weight (dict) – The same structure as base_result, where each element is a binary integer indicating whether a higher or lower score is better.
- Returns
The improvement between new_result and base_result.
- Return type
float
Examples:
base_result = {'A': [96, 98], 'B': [0.2]}
new_result = {'A': [93, 99], 'B': [0.5]}
weight = {'A': [1, 0], 'B': [1]}
print(count_improvement(base_result, new_result, weight))
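To make the formula easier to follow, the sketch below evaluates it term by term in plain Python, exactly as printed above, with B and N read from base_result and new_result and w from weight. The name delta_p is hypothetical and this is not the library's count_improvement; the library function may differ in scaling or in other details.

```python
def delta_p(base_result, new_result, weight):
    """Evaluate the printed formula for Delta_p, returned as a percentage."""
    per_task = []
    for task, base_scores in base_result.items():
        # One term per metric m of task t: (-1)^w * (B - N) / N
        terms = [((-1) ** w) * (b - n) / n
                 for b, n, w in zip(base_scores, new_result[task], weight[task])]
        per_task.append(sum(terms) / len(terms))  # average over the M_t metrics
    return 100.0 * sum(per_task) / len(per_task)  # average over the T tasks


base_result = {'A': [96, 98], 'B': [0.2]}
new_result = {'A': [93, 99], 'B': [0.5]}
weight = {'A': [1, 0], 'B': [1]}
improvement = delta_p(base_result, new_result, weight)
```

With a weight of 1 (higher is better), a metric that rises from the base to the new result contributes a positive term, and one that falls contributes a negative term.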
LibMTL.config¶
LibMTL.metrics¶
- class AbsMetric[source]¶
Bases: object
An abstract class for the performance metrics of a task.
- record¶
A list of the metric scores in every iteration.
- Type
list
- bs¶
A list of the number of data in every iteration.
- Type
list
- property update_fun(self, pred, gt)[source]¶
Calculate the metric scores in every iteration and update record.
- Parameters
pred (torch.Tensor) – The prediction tensor.
gt (torch.Tensor) – The ground-truth tensor.