LibMTL.weighting.CAGrad
- class CAGrad
Bases:
LibMTL.weighting.abstract_weighting.AbsWeighting
Conflict-Averse Gradient descent (CAGrad).
This method is proposed in Conflict-Averse Gradient Descent for Multi-task Learning (NeurIPS 2021) and is implemented by adapting the official PyTorch implementation.
- Parameters
calpha (float, default=0.5) – A hyperparameter that controls the convergence rate.
rescale ({0, 1, 2}, default=1) – The type of the gradient rescaling.
Warning
CAGrad does not support representation gradients, i.e., rep_grad must be False.
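To illustrate how the calpha and rescale hyperparameters enter the update, the following is a minimal sketch of the conflict-averse gradient combination described in the CAGrad paper, written for flattened task gradients. The function name cagrad_update and its signature are illustrative assumptions, not part of LibMTL's API.

```python
# Illustrative sketch of the CAGrad update rule (not LibMTL's API).
# grads: (num_tasks, dim) array of flattened per-task gradients.
import numpy as np
from scipy.optimize import minimize

def cagrad_update(grads, calpha=0.5, rescale=1):
    num_tasks = grads.shape[0]
    g0 = grads.mean(axis=0)                 # average gradient
    phi = (calpha ** 2) * (g0 @ g0)         # constraint radius c^2 ||g0||^2

    def objective(w):
        gw = w @ grads                      # convex combination of gradients
        return gw @ g0 + np.sqrt(phi) * np.linalg.norm(gw)

    # Minimize the objective over the probability simplex.
    cons = {"type": "eq", "fun": lambda w: w.sum() - 1.0}
    bounds = [(0.0, 1.0)] * num_tasks
    w0 = np.ones(num_tasks) / num_tasks
    res = minimize(objective, w0, bounds=bounds, constraints=cons)

    gw = res.x @ grads
    gw_norm = np.linalg.norm(gw)
    lmbda = np.sqrt(phi) / gw_norm if gw_norm > 0 else 0.0
    d = g0 + lmbda * gw                     # conflict-averse update direction

    # rescale in {0, 1, 2} selects the gradient rescaling variant.
    if rescale == 1:
        d = d / (1 + calpha ** 2)
    elif rescale == 2:
        d = d / (1 + calpha)
    return d
```

When all task gradients agree, the update reduces to a rescaled average gradient; calpha controls how far the combined direction may deviate from that average to reduce conflict between tasks.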