LibMTL.weighting.MGDA

class MGDA

Bases: LibMTL.weighting.abstract_weighting.AbsWeighting

Multiple Gradient Descent Algorithm (MGDA).

This method is proposed in Multi-Task Learning as Multi-Objective Optimization (NeurIPS 2018) and is adapted from the official PyTorch implementation.
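In the formulation of that paper, MGDA searches at every step for the convex combination of per-task gradients with minimum norm; when this combination is non-zero, it is a common descent direction for all tasks:

    \min_{\alpha_1, \ldots, \alpha_T} \Big\| \sum_{t=1}^{T} \alpha_t \, \nabla_{\theta^{sh}} \mathcal{L}_t \Big\|_2^2
    \quad \text{s.t.} \quad \sum_{t=1}^{T} \alpha_t = 1, \;\; \alpha_t \ge 0 \;\; \forall t,

where \theta^{sh} denotes the shared parameters and \mathcal{L}_t the loss of task t.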

Parameters

mgda_gn ({'none', 'l2', 'loss', 'loss+'}, default='none') – The type of gradient normalization applied to the task gradients before the min-norm problem is solved.
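As a rough guide, based on the gradient normalization used in the official MGDA implementation (LibMTL's exact handling may differ in detail, and the helper name normalize_grad below is hypothetical), each option rescales a task gradient as follows:

```python
import torch

def normalize_grad(grad: torch.Tensor, loss: torch.Tensor, mgda_gn: str = 'none') -> torch.Tensor:
    """Rescale one task's gradient according to the chosen mgda_gn mode (illustrative sketch)."""
    if mgda_gn == 'l2':
        scale = grad.norm()                   # divide by the gradient's L2 norm
    elif mgda_gn == 'loss':
        scale = loss.detach()                 # divide by the task loss value
    elif mgda_gn == 'loss+':
        scale = loss.detach() * grad.norm()   # divide by loss times gradient norm
    else:                                     # 'none'
        scale = torch.tensor(1.0)
    # The small epsilon is an added safeguard against division by zero.
    return grad / (scale + 1e-8)
```

With 'l2', the min-norm solution depends only on the directions of the task gradients rather than their magnitudes.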

backward(self, losses, **kwargs)
Parameters
  • losses (list) – A list containing the loss of each task.

  • kwargs (dict) – A dictionary of hyperparameters for the weighting method.
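The following is a minimal, self-contained sketch of what a call like backward(losses) does conceptually for two tasks; it is illustrative only, not LibMTL's implementation, and the toy model, data, and losses are made up for the example. It computes per-task gradients on the shared parameters, solves the min-norm problem for the combination weights, and writes the combined gradient back before the optimizer step.

```python
import torch

shared = torch.nn.Linear(4, 4)                        # toy shared module
optimizer = torch.optim.SGD(shared.parameters(), lr=0.1)

x = torch.randn(8, 4)
losses = [shared(x).pow(2).mean(),                    # toy loss for task 1
          (shared(x) - 1).abs().mean()]               # toy loss for task 2

# Per-task gradients w.r.t. the shared parameters, flattened into vectors.
grads = []
for loss in losses:
    g = torch.autograd.grad(loss, list(shared.parameters()), retain_graph=True)
    grads.append(torch.cat([gi.reshape(-1) for gi in g]))

# Closed-form min-norm weights for two tasks:
#   minimize ||a * g1 + (1 - a) * g2||^2 over a in [0, 1].
diff = grads[1] - grads[0]
a = (torch.dot(diff, grads[1]) / diff.pow(2).sum().clamp_min(1e-12)).clamp(0.0, 1.0)
combined = a * grads[0] + (1 - a) * grads[1]          # common descent direction

# Write the combined gradient into .grad and take a normal optimizer step.
optimizer.zero_grad()
offset = 0
for p in shared.parameters():
    n = p.numel()
    p.grad = combined[offset:offset + n].view_as(p).clone()
    offset += n
optimizer.step()
```

With more than two tasks, the min-norm weights are found iteratively (e.g. by Frank-Wolfe, as in the paper) rather than in closed form.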