Introduction
LibMTL is an open-source library built on PyTorch for Multi-Task Learning (MTL). This library has the following three characteristics.
Unified: LibMTL provides a unified code base and a consistent evaluation procedure, covering data processing, metric objectives, and hyper-parameters, on several representative MTL benchmark datasets. This allows quantitative, fair, and consistent comparisons between different MTL algorithms.
Comprehensive: LibMTL supports many state-of-the-art MTL methods, including 8 architectures and 16 optimization strategies. LibMTL also provides a fair comparison on several benchmark datasets covering different fields.
Extensible: LibMTL follows modular design principles, which allow users to add customized components or make their own modifications flexibly and conveniently. Users can therefore quickly develop novel loss weighting strategies and architectures, or apply the existing MTL algorithms to new application scenarios, with the support of LibMTL.
Supported Algorithms
LibMTL currently supports the following algorithms:
| Optimization Strategies | Venues | Arguments |
|---|---|---|
| Equal Weighting (EW) | - | --weighting EW |
| Gradient Normalization (GradNorm) | ICML 2018 | --weighting GradNorm |
| Uncertainty Weights (UW) | CVPR 2018 | --weighting UW |
| MGDA (official code) | NeurIPS 2018 | --weighting MGDA |
| Dynamic Weight Average (DWA) (official code) | CVPR 2019 | --weighting DWA |
| Geometric Loss Strategy (GLS) | CVPR 2019 Workshop | --weighting GLS |
| Projecting Conflicting Gradient (PCGrad) | NeurIPS 2020 | --weighting PCGrad |
| Gradient sign Dropout (GradDrop) | NeurIPS 2020 | --weighting GradDrop |
| Impartial Multi-Task Learning (IMTL) | ICLR 2021 | --weighting IMTL |
| Gradient Vaccine (GradVac) | ICLR 2021 | --weighting GradVac |
| Conflict-Averse Gradient descent (CAGrad) (official code) | NeurIPS 2021 | --weighting CAGrad |
| Nash-MTL (official code) | ICML 2022 | --weighting Nash_MTL |
| Random Loss Weighting (RLW) | TMLR 2022 | --weighting RLW |
| MoCo | ICLR 2023 | --weighting MoCo |
| Aligned-MTL (official code) | CVPR 2023 | --weighting Aligned_MTL |
| DB-MTL | arXiv | --weighting DB_MTL |
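To make the idea of a loss weighting strategy concrete, the following is a minimal Python sketch of two entries from the table above: Equal Weighting (EW) and Dynamic Weight Average (DWA). It is not LibMTL's implementation. The function names `equal_weights`, `dwa_weights`, and `weighted_total` are hypothetical, and the DWA rule follows its usual description (a softmax over the ratio of each task's two most recent losses, scaled so the weights sum to the number of tasks), with the temperature default chosen here for illustration.

```python
import math
from typing import List, Sequence


def equal_weights(num_tasks: int) -> List[float]:
    """Equal Weighting (EW): every task loss gets weight 1."""
    return [1.0] * num_tasks


def dwa_weights(prev_losses: Sequence[float],
                prev_prev_losses: Sequence[float],
                temperature: float = 2.0) -> List[float]:
    """Dynamic Weight Average (DWA) weights for the next epoch.

    r_k = L_k(t-1) / L_k(t-2) measures how slowly task k's loss is falling.
    The weights are K * softmax(r / T), so they always sum to K.
    The default temperature is an illustrative assumption.
    """
    num_tasks = len(prev_losses)
    ratios = [a / b for a, b in zip(prev_losses, prev_prev_losses)]
    exps = [math.exp(r / temperature) for r in ratios]
    total = sum(exps)
    return [num_tasks * e / total for e in exps]


def weighted_total(losses: Sequence[float], weights: Sequence[float]) -> float:
    """Combine per-task losses into the single scalar that is optimized."""
    return sum(w * l for w, l in zip(weights, losses))


# Task 0 halved its loss (1.0 -> 0.5); task 1 stalled (1.0 -> 1.0).
weights = dwa_weights(prev_losses=[0.5, 1.0], prev_prev_losses=[1.0, 1.0])
print(weights)
print(weighted_total([0.5, 1.0], weights))
```

In this example the stalled task receives the larger weight, which is the intended effect of DWA: tasks whose losses decrease more slowly are emphasized in the next epoch.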
| Architectures | Venues | Arguments |
|---|---|---|
| Hard Parameter Sharing (HPS) | ICML 1993 | --arch HPS |
| Cross-stitch Networks (Cross_stitch) | CVPR 2016 | --arch Cross_stitch |
| Multi-gate Mixture-of-Experts (MMoE) | KDD 2018 | --arch MMoE |
| Multi-Task Attention Network (MTAN) (official code) | CVPR 2019 | --arch MTAN |
| Customized Gate Control (CGC), Progressive Layered Extraction (PLE) | ACM RecSys 2020 | --arch CGC, --arch PLE |
| Learning to Branch (LTB) | ICML 2020 | --arch LTB |
| DSelect-k (official code) | NeurIPS 2021 | --arch DSelect_k |
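The values in the Arguments columns are passed on the command line. As a rough illustration of how such flags can be validated, here is a minimal `argparse` sketch. It is a hypothetical stand-in, not LibMTL's own parser: the names `WEIGHTINGS`, `ARCHITECTURES`, and `build_parser`, and the choice of `EW` and `HPS` as defaults, are assumptions made for this example.

```python
import argparse

# Values copied from the "Arguments" columns of the tables above.
WEIGHTINGS = [
    "EW", "GradNorm", "UW", "MGDA", "DWA", "GLS", "PCGrad", "GradDrop",
    "IMTL", "GradVac", "CAGrad", "Nash_MTL", "RLW", "MoCo",
    "Aligned_MTL", "DB_MTL",
]
ARCHITECTURES = [
    "HPS", "Cross_stitch", "MMoE", "MTAN", "CGC", "PLE", "LTB", "DSelect_k",
]


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser that accepts only the listed method names."""
    parser = argparse.ArgumentParser(description="Select an MTL method")
    # Defaults are assumptions for this sketch.
    parser.add_argument("--weighting", choices=WEIGHTINGS, default="EW")
    parser.add_argument("--arch", choices=ARCHITECTURES, default="HPS")
    return parser


# Parse an explicit argument list instead of sys.argv for demonstration.
args = build_parser().parse_args(["--weighting", "DWA", "--arch", "MTAN"])
print(f"weighting={args.weighting} arch={args.arch}")
```

Restricting each flag to a fixed list of `choices` means a misspelled method name is rejected with a usage message before any training starts.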
Supported Benchmark Datasets
| Datasets | Problems | Task Number | Tasks | Multi-input | Supported Backbone |
|---|---|---|---|---|---|
| NYUv2 | Scene Understanding | 3 | Semantic Segmentation + Depth Estimation + Surface Normal Prediction | ✘ | ResNet50 / SegNet |
| Office-31 | Image Recognition | 3 | Classification | ✓ | ResNet18 |
| Office-Home | Image Recognition | 4 | Classification | ✓ | ResNet18 |
| QM9 | Molecular Property Prediction | 11 (default) | Regression | ✘ | GNN |
| PAWS-X | Paraphrase Identification | 4 (default) | Classification | ✓ | BERT |
Citation
If you find LibMTL useful for your research or development, please cite the following:
@article{lin2023libmtl,
title={{LibMTL}: A {P}ython Library for Multi-Task Learning},
author={Baijiong Lin and Yu Zhang},
journal={Journal of Machine Learning Research},
volume={24},
number={209},
pages={1--7},
year={2023}
}
Contributors
LibMTL is developed and maintained by Baijiong Lin.
Contact Us
If you have any questions or suggestions, please feel free to contact us by raising an issue or sending an email to bj.lin.email@gmail.com.
Acknowledgements
We would like to thank the authors who released the following public repositories (listed alphabetically): CAGrad, dselect_k_moe, MultiObjectiveOptimization, mtan, MTL, nash-mtl, pytorch_geometric, and xtreme.