LibMTL.weighting.GradVac

class GradVac

Bases: LibMTL.weighting.abstract_weighting.AbsWeighting

Gradient Vaccine (GradVac).

This method was proposed in "Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models" (ICLR 2021 Spotlight) and is implemented by us.

Parameters
  • GradVac_beta (float, default=0.5) – The exponential moving average (EMA) decay parameter.

  • GradVac_group_type (int, default=0) – The granularity at which parameters are grouped when gradient similarity is computed (0: whole_model; 1: all_layer; 2: all_matrix).
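To illustrate how GradVac_beta enters the method, the following is a minimal sketch of one GradVac update for a single pair of task gradients within one parameter group. It is not the LibMTL implementation: the function names (cosine, gradvac_step) are hypothetical, gradients are plain Python lists, and the EMA convention (1 - beta) * old + beta * new is assumed from the paper.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def norm(u):
    return math.sqrt(dot(u, u))

def cosine(u, v, eps=1e-8):
    # Cosine similarity with a small eps to avoid division by zero.
    return dot(u, v) / (norm(u) * norm(v) + eps)

def gradvac_step(g_i, g_j, rho_target, beta=0.5, eps=1e-8):
    """Hypothetical helper: adjust g_i against g_j for one parameter group.

    Returns the (possibly adjusted) gradient and the updated EMA target.
    """
    rho = cosine(g_i, g_j, eps)
    new_g_i = list(g_i)
    if rho < rho_target:
        # Coefficient that moves the cosine similarity up to rho_target.
        sin_rho = math.sqrt(max(0.0, 1.0 - rho ** 2))
        sin_target = math.sqrt(max(0.0, 1.0 - rho_target ** 2))
        w = norm(g_i) * (rho_target * sin_rho - rho * sin_target) / (
            norm(g_j) * sin_target + eps
        )
        new_g_i = [a + w * b for a, b in zip(g_i, g_j)]
    # EMA of the observed similarity (convention assumed from the paper).
    new_target = (1.0 - beta) * rho_target + beta * rho
    return new_g_i, new_target

# Two conflicting gradients and an initial similarity target of 0.
adjusted, target = gradvac_step([1.0, 0.0], [-1.0, 1.0], rho_target=0.0)
```

With a target of 0 the adjustment reduces to projecting out the conflicting component. Under GradVac_group_type 1 or 2, the same step would presumably be applied per layer or per parameter matrix rather than to the whole flattened gradient.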

Warning

GradVac does not support representation gradients, i.e., rep_grad must be False.
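The constraint above and the documented defaults can be checked before training starts. The helper below is a hypothetical pre-flight check, not part of LibMTL; only the parameter names GradVac_beta, GradVac_group_type, and rep_grad are taken from this page.

```python
def check_gradvac_settings(rep_grad, weight_args=None):
    """Hypothetical helper (not a LibMTL function): validate GradVac settings."""
    if rep_grad:
        raise ValueError("GradVac requires rep_grad=False")
    weight_args = dict(weight_args or {})
    # Fall back to the documented defaults when a key is missing.
    beta = float(weight_args.get("GradVac_beta", 0.5))
    group_type = int(weight_args.get("GradVac_group_type", 0))
    if group_type not in (0, 1, 2):
        raise ValueError("GradVac_group_type must be 0, 1, or 2")
    return {"GradVac_beta": beta, "GradVac_group_type": group_type}

settings = check_gradvac_settings(rep_grad=False, weight_args={"GradVac_group_type": 1})
```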

init_param(self)
backward(self, losses, **kwargs)