LibMTL.weighting.CAGrad

class CAGrad

Bases: LibMTL.weighting.abstract_weighting.AbsWeighting

Conflict-Averse Gradient descent (CAGrad).

This method was proposed in Conflict-Averse Gradient Descent for Multi-task Learning (NeurIPS 2021), and this implementation is adapted from the official PyTorch implementation.
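For reference, the update can be summarized as follows. This is a summary of the paper's formulation in our own notation, not an excerpt from the LibMTL source: g_i is the gradient of task i, K is the number of tasks, Δ_K is the probability simplex over the K tasks, and c corresponds to calpha. The primal problem picks the direction that maximizes the worst task's improvement within a ball around the average gradient. In practice one solves the K-dimensional dual problem for w* and then recovers d* in closed form.

```latex
\begin{aligned}
g_0 &= \frac{1}{K}\sum_{i=1}^{K} g_i, \qquad g_w = \sum_{i=1}^{K} w_i\, g_i,\\
d^{*} &= \arg\max_{d}\ \min_{1 \le i \le K} \langle g_i, d \rangle
  \quad \text{s.t.}\quad \lVert d - g_0 \rVert \le c\,\lVert g_0 \rVert,\\
w^{*} &= \arg\min_{w \in \Delta_K}\ g_w^{\top} g_0 + c\,\lVert g_0 \rVert\,\lVert g_w \rVert,\\
d^{*} &= g_0 + \frac{c\,\lVert g_0 \rVert}{\lVert g_{w^{*}} \rVert}\, g_{w^{*}}.
\end{aligned}
```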

Parameters
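  • calpha (float, default=0.5) – A hyperparameter that controls the convergence rate. It scales the radius of the region around the average gradient in which the update direction is searched; calpha=0 reduces CAGrad to plain gradient descent on the average loss.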

  • rescale ({0, 1, 2}, default=1) – The type of rescaling applied to the combined gradient.
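The sketch below shows how calpha and rescale enter the computation. It works on a plain NumPy array grads of shape [K, D] (one flattened gradient per task); the function name cagrad_direction is hypothetical and not part of LibMTL's API. It solves the dual problem over the probability simplex with SciPy's minimize (which selects SLSQP because bounds and an equality constraint are given). The mapping of rescale 0, 1 and 2 to dividing by 1, 1 + calpha² and 1 + calpha is an assumption based on our reading of the official CAGrad code.

```python
import numpy as np
from scipy.optimize import minimize


def cagrad_direction(grads: np.ndarray, calpha: float = 0.5, rescale: int = 1) -> np.ndarray:
    """Hypothetical helper: combine per-task gradients [K, D] into one CAGrad direction."""
    num_tasks = grads.shape[0]
    gram = grads @ grads.T                    # [K, K] pairwise dot products g_i . g_j
    g0 = grads.mean(axis=0)                   # average gradient g_0
    g0_norm = np.sqrt(gram.mean() + 1e-8)     # ||g_0||, computed from the Gram matrix
    c = calpha * g0_norm                      # radius of the ball around g_0

    b = np.full(num_tasks, 1.0 / num_tasks)   # uniform weights, so w @ gram @ b = g_w . g_0

    def dual_objective(w: np.ndarray) -> float:
        # g_w^T g_0 + c * ||g_w||, where g_w = sum_i w_i g_i
        return w @ gram @ b + c * np.sqrt(w @ gram @ w + 1e-8)

    res = minimize(
        dual_objective,
        x0=b.copy(),
        bounds=[(0.0, 1.0)] * num_tasks,
        constraints=({"type": "eq", "fun": lambda w: 1.0 - w.sum()},),
    )
    gw = res.x @ grads                        # weighted gradient g_w at the dual optimum
    lmbda = c / (np.linalg.norm(gw) + 1e-8)
    g = g0 + lmbda * gw                       # CAGrad direction before rescaling

    # Rescale options: assumed to mirror the official CAGrad code.
    if rescale == 0:
        return g
    if rescale == 1:
        return g / (1.0 + calpha ** 2)
    if rescale == 2:
        return g / (1.0 + calpha)
    raise ValueError(f"unsupported rescale type: {rescale}")


# Usage: two partially conflicting task gradients.
grads = np.array([[1.0, 0.0], [-0.5, 1.0]])
direction = cagrad_direction(grads, calpha=0.5, rescale=0)
```

With calpha=0 the direction collapses to the plain average gradient; larger values let the update move further from the average toward directions that improve the worst-off task.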

Warning

CAGrad does not support representation gradients, i.e., rep_grad must be False.

backward(self, losses, **kwargs)
Parameters
  • losses (list) – A list of losses of each task.

  • kwargs (dict) – A dictionary of hyperparameters of the weighting method; for CAGrad it must contain calpha and rescale.
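LibMTL's internal gradient helpers are not documented on this page. The plain PyTorch sketch below, with hypothetical helper names per_task_gradients and write_back and a toy two-task model, only illustrates the general pattern a gradient-based weighting method relies on: one flattened gradient row per task loss over the shared parameters, then a single combined gradient copied into .grad before the optimizer step. Simple averaging stands in for the combination step that CAGrad would perform.

```python
import torch
import torch.nn as nn


def per_task_gradients(losses, shared_params):
    """Hypothetical helper: stack each task's gradient w.r.t. the shared params into [K, D]."""
    rows = []
    for loss in losses:
        # retain_graph=True because all task losses share one forward graph
        grads = torch.autograd.grad(loss, shared_params, retain_graph=True, allow_unused=True)
        flat = [
            (g if g is not None else torch.zeros_like(p)).reshape(-1)
            for g, p in zip(grads, shared_params)
        ]
        rows.append(torch.cat(flat))
    return torch.stack(rows)


def write_back(update, shared_params):
    """Hypothetical helper: copy a flat update vector into the .grad fields."""
    offset = 0
    for p in shared_params:
        n = p.numel()
        p.grad = update[offset:offset + n].view_as(p).clone()
        offset += n


torch.manual_seed(0)
encoder = nn.Linear(4, 3)                      # toy shared encoder
heads = [nn.Linear(3, 1), nn.Linear(3, 1)]     # one toy head per task
x = torch.randn(8, 4)
features = encoder(x)
losses = [head(features).pow(2).mean() for head in heads]  # list of task losses

params = list(encoder.parameters())            # weight (12 values) then bias (3 values)
G = per_task_gradients(losses, params)         # shape [2, 15]
write_back(G.mean(dim=0), params)              # averaging as a placeholder for CAGrad
```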