Source code for LibMTL.weighting.Nash_MTL

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

# cvxpy is required to solve the Nash bargaining problem; install it on the
# fly if it is missing.
try:
    import cvxpy as cp
except ModuleNotFoundError:
    from pip._internal import main as pip
    pip(['install', '--user', 'cvxpy'])
    import cvxpy as cp

class Nash_MTL(AbsWeighting):
    r"""Nash-MTL.

    This method is proposed in `Multi-Task Learning as a Bargaining Game (ICML 2022)
    <https://proceedings.mlr.press/v162/navon22a/navon22a.pdf>`_ and implemented by
    modifying the `official PyTorch implementation <https://github.com/AvivNavon/nash-mtl>`_.

    Args:
        update_weights_every (int, default=1): Period of weights update.
        optim_niter (int, default=20): The maximum iteration count of the optimization solver.
        max_norm (float, default=1.0): The max norm of the gradients.

    .. warning::
        Nash_MTL does not support representation gradients, i.e., ``rep_grad`` must be ``False``.

    """
    def __init__(self):
        super(Nash_MTL, self).__init__()
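    # A minimal usage sketch (an assumption based on LibMTL's standard Trainer
    # entry point, not part of this module): the three Args above are passed
    # through the ``weight_args`` dict, e.g.
    #
    #     from LibMTL import Trainer
    #     from LibMTL.weighting import Nash_MTL
    #
    #     trainer = Trainer(..., weighting=Nash_MTL, rep_grad=False,
    #                       weight_args={'update_weights_every': 1,
    #                                    'optim_niter': 20,
    #                                    'max_norm': 1.0})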
    def init_param(self):
        self.step = 0
        self.prvs_alpha_param = None
        self.init_gtg = np.eye(self.task_num)
        self.prvs_alpha = np.ones(self.task_num, dtype=np.float32)
        self.normalization_factor = np.ones((1,))
    def _stop_criteria(self, gtg, alpha_t):
        return (
            (self.alpha_param.value is None)
            or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3)
            or (
                np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value)
                < 1e-6
            )
        )
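    # The criteria above encode the Nash bargaining fixed point: at the
    # solution the gradient Gram matrix G G^T and the weights alpha satisfy
    # (G G^T) alpha = 1 / alpha elementwise, so the solver stops once this
    # residual is small or alpha stops changing between iterations.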
    def solve_optimization(self, gtg: np.ndarray):
        self.G_param.value = gtg
        self.normalization_factor_param.value = self.normalization_factor

        alpha_t = self.prvs_alpha
        for _ in range(self.optim_niter):
            self.alpha_param.value = alpha_t
            self.prvs_alpha_param.value = alpha_t
            try:
                self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100)
            except Exception:
                # keep the last feasible iterate if the solver fails
                self.alpha_param.value = self.prvs_alpha_param.value
            if self._stop_criteria(gtg, alpha_t):
                break
            alpha_t = self.alpha_param.value
        if alpha_t is not None:
            self.prvs_alpha = alpha_t
        return self.prvs_alpha
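    # Sanity check (a hypothetical example, not part of the library): for a
    # diagonal Gram matrix diag(d) the fixed point (G G^T) alpha = 1 / alpha
    # solves to alpha_i = 1 / sqrt(d_i), which the solver should recover:
    #
    #     >>> gtg = np.diag([4.0, 1.0])
    #     >>> # expected solution: alpha ~ [0.5, 1.0], since 4 * 0.5 = 1 / 0.5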
    def _calc_phi_alpha_linearization(self):
        G_prvs_alpha = self.G_param @ self.prvs_alpha_param
        prvs_phi_tag = 1 / self.prvs_alpha_param + (1 / G_prvs_alpha) @ self.G_param
        phi_alpha = prvs_phi_tag @ (self.alpha_param - self.prvs_alpha_param)
        return phi_alpha

    def _init_optim_problem(self):
        self.alpha_param = cp.Variable(shape=(self.task_num,), nonneg=True)
        self.prvs_alpha_param = cp.Parameter(
            shape=(self.task_num,), value=self.prvs_alpha
        )
        self.G_param = cp.Parameter(
            shape=(self.task_num, self.task_num), value=self.init_gtg
        )
        self.normalization_factor_param = cp.Parameter(
            shape=(1,), value=np.array([1.0])
        )

        self.phi_alpha = self._calc_phi_alpha_linearization()

        G_alpha = self.G_param @ self.alpha_param
        constraint = []
        for i in range(self.task_num):
            constraint.append(
                -cp.log(self.alpha_param[i] * self.normalization_factor_param)
                - cp.log(G_alpha[i])
                <= 0
            )
        obj = cp.Minimize(
            cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param
        )
        self.prob = cp.Problem(obj, constraint)
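    # The constraints above encode alpha_i * (G alpha)_i >= 1 / c, where c is
    # the normalization factor, keeping every log term well defined; the
    # objective minimizes cp.sum(G_alpha) plus phi_alpha, the first-order
    # linearization of the bargaining objective around prvs_alpha.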
    def backward(self, losses, **kwargs):
        self.update_weights_every = kwargs['update_weights_every']
        self.optim_niter = kwargs['optim_niter']
        self.max_norm = kwargs['max_norm']

        if self.step == 0:
            self._init_optim_problem()
        if (self.step % self.update_weights_every) == 0:
            self.step += 1

            if self.rep_grad:
                raise ValueError('Nash_MTL does not support representation gradients (rep_grad=True)')
            else:
                self._compute_grad_dim()
                grads = self._compute_grad(losses, mode='autograd')

            GTG = torch.mm(grads, grads.t())
            self.normalization_factor = torch.norm(GTG).detach().cpu().numpy().reshape((1,))
            GTG = GTG / self.normalization_factor.item()
            alpha = self.solve_optimization(GTG.cpu().detach().numpy())
        else:
            self.step += 1
            alpha = self.prvs_alpha

        alpha = torch.from_numpy(alpha).to(torch.float32).to(self.device)
        torch.sum(alpha * losses).backward()

        if self.max_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.get_share_params(), self.max_norm)
        return alpha.detach().cpu().numpy()
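    # The returned numpy array holds the task weights actually applied at this
    # step, so callers can log or inspect them. A hypothetical direct call
    # (normally LibMTL's Trainer drives this, after task_num, device, rep_grad,
    # and the shared model have been initialized):
    #
    #     >>> weights = nash_mtl.backward(losses, update_weights_every=1,
    #     ...                             optim_niter=20, max_norm=1.0)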