LibMTL.weighting.Nash_MTL

class Nash_MTL[source]

Bases: LibMTL.weighting.abstract_weighting.AbsWeighting

Nash-MTL.

This method is proposed in Multi-Task Learning as a Bargaining Game (ICML 2022) and implemented by modifying the official PyTorch implementation.

Parameters
  • update_weights_every (int, default=1) – The number of iterations between task-weight updates.

  • optim_niter (int, default=20) – The maximum number of iterations for the optimization solver.

  • max_norm (float, default=1.0) – The maximum norm of the gradients.

Warning

Nash_MTL does not support representation gradients, i.e., rep_grad must be False.
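
The hyperparameters above are forwarded to the weighting method as keyword arguments of backward (documented below). The following is a minimal configuration sketch; the variable names nash_mtl_args and rep_grad shown here are illustrative assumptions, not part of LibMTL's documented API.

# Illustrative configuration; only the keys correspond to the parameters
# documented above, the variable names are assumptions.
nash_mtl_args = {
    'update_weights_every': 1,  # recompute the task weights every iteration
    'optim_niter': 20,          # iteration cap for the inner optimization solver
    'max_norm': 1.0,            # maximum norm applied to the gradients
}
rep_grad = False  # Nash_MTL requires rep_grad to be False (see the warning above)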

init_param(self)[source]

Define and initialize the trainable parameters required by the weighting method.

solve_optimization(self, gtg: numpy.array)[source]
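
solve_optimization carries no docstring. Following the Nash-MTL paper, gtg is presumably the Gram matrix G^T G of the per-task gradients, and the solver searches for positive weights alpha satisfying the first-order condition (G^T G) alpha = 1/alpha (elementwise). The sketch below solves that condition with a generic least-squares routine under these assumptions; it is only an illustration, not LibMTL's actual solver, whose work is bounded by optim_niter.

import numpy as np
from scipy.optimize import least_squares

def nash_bargaining_weights(gtg: np.ndarray) -> np.ndarray:
    # Solve (G^T G) @ alpha = 1 / alpha for elementwise-positive alpha.
    # Optimizing over log(alpha) keeps alpha strictly positive.
    def residual(log_alpha):
        alpha = np.exp(log_alpha)
        return gtg @ alpha - 1.0 / alpha
    sol = least_squares(residual, x0=np.zeros(gtg.shape[0]))
    return np.exp(sol.x)

# Toy usage: 3 tasks with flattened gradients of length 1000.
grads = np.random.randn(3, 1000)
alpha = nash_bargaining_weights(grads @ grads.T)
update_direction = alpha @ grads  # weighted combination of task gradients
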
backward(self, losses, **kwargs)[source]

Parameters
  • losses (list) – A list containing the loss of each task.

  • kwargs (dict) – A dictionary of hyperparameters for the weighting method.
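
As an illustration of the expected arguments, losses is a list with one scalar loss per task and kwargs carries the hyperparameters documented for this class. The snippet below only shows their shapes; the model that would actually call backward is assumed and omitted.

import torch

# One scalar loss per task; in practice each would be a differentiable
# tensor produced by the shared model, not a constant.
losses = [torch.tensor(0.82), torch.tensor(1.37), torch.tensor(0.54)]
kwargs = {'update_weights_every': 1, 'optim_niter': 20, 'max_norm': 1.0}
# In a LibMTL training loop the call would look roughly like
# model.backward(losses, **kwargs), where model mixes Nash_MTL in;
# that wiring is an assumption and is not shown here.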