Source code for LibMTL.weighting.MoCo

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

class MoCo(AbsWeighting):
    r"""MoCo.

    This method is proposed in `Mitigating Gradient Bias in Multi-objective Learning: A Provably Convergent Approach (ICLR 2023) <https://openreview.net/forum?id=dLAYGdKTi2>`_
    and implemented based on the authors' shared code (Heshan Fernando: fernah@rpi.edu).

    Args:
        MoCo_beta (float, default=0.5): The learning rate of y.
        MoCo_beta_sigma (float, default=0.5): The decay rate of MoCo_beta.
        MoCo_gamma (float, default=0.1): The learning rate of lambd.
        MoCo_gamma_sigma (float, default=0.5): The decay rate of MoCo_gamma.
        MoCo_rho (float, default=0): The :math:`\ell_2` regularization parameter of lambda's update.

    .. warning::
            MoCo does not support representation gradients, i.e., ``rep_grad`` must be ``False``.

    """
    def __init__(self):
        super(MoCo, self).__init__()
    def init_param(self):
        self._compute_grad_dim()
        self.step = 0
        # y tracks a momentum-smoothed estimate of each task's flattened gradient.
        self.y = torch.zeros(self.task_num, self.grad_dim).to(self.device)
        # lambd holds the task weights, initialized uniformly on the simplex.
        self.lambd = (torch.ones([self.task_num, ]) / self.task_num).to(self.device)
    def backward(self, losses, **kwargs):
        self.step += 1
        beta, beta_sigma = kwargs['MoCo_beta'], kwargs['MoCo_beta_sigma']
        gamma, gamma_sigma = kwargs['MoCo_gamma'], kwargs['MoCo_gamma_sigma']
        rho = kwargs['MoCo_rho']
        if self.rep_grad:
            raise ValueError('MoCo does not support representation gradients (rep_grad=True)')
        else:
            self._compute_grad_dim()
            grads = self._compute_grad(losses, mode='backward')
        with torch.no_grad():
            # Normalize each task gradient and rescale it by its loss value.
            for tn in range(self.task_num):
                grads[tn] = grads[tn] / (grads[tn].norm() + 1e-8) * losses[tn]
        # Momentum update of the per-task gradient tracker y; the step size decays with self.step.
        self.y = self.y - (beta / self.step ** beta_sigma) * (self.y - grads)
        # Softmax (mirror-descent) update of the task weights lambd with \ell_2 regularization rho.
        self.lambd = F.softmax(self.lambd - (gamma / self.step ** gamma_sigma) * (self.y @ self.y.t() + rho * torch.eye(self.task_num).to(self.device)) @ self.lambd, -1)
        # The weighted combination of the tracked gradients is written back to the model parameters.
        new_grads = self.y.t() @ self.lambd
        self._reset_grad(new_grads)
        return self.lambd.detach().cpu().numpy()
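
The hyperparameters documented in the class docstring are read out of ``kwargs`` on every call to ``backward``. As a rough illustration (the dictionary below is not part of LibMTL; it simply pairs the expected keys with their documented defaults):

moco_kwargs = {
    'MoCo_beta': 0.5,         # learning rate of y
    'MoCo_beta_sigma': 0.5,   # decay rate of MoCo_beta
    'MoCo_gamma': 0.1,        # learning rate of lambd
    'MoCo_gamma_sigma': 0.5,  # decay rate of MoCo_gamma
    'MoCo_rho': 0.0,          # \ell_2 regularization of lambda's update
}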
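
The following is a minimal, self-contained sketch (not part of LibMTL) of the two update rules used in ``backward``: the momentum update of the gradient tracker ``y`` and the softmax update of the task weights ``lambd``. The toy gradients and the number of steps are assumptions made purely for illustration; in the library the gradients come from ``self._compute_grad`` and are additionally scaled by the task losses.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
task_num, grad_dim = 3, 5
beta, beta_sigma = 0.5, 0.5    # learning rate / decay rate for y
gamma, gamma_sigma = 0.1, 0.5  # learning rate / decay rate for lambd
rho = 0.0                      # \ell_2 regularization of lambd's update

y = torch.zeros(task_num, grad_dim)
lambd = torch.ones(task_num) / task_num

for step in range(1, 6):
    # Stand-ins for the normalized task gradients (backward() also scales them by the losses).
    grads = torch.randn(task_num, grad_dim)
    grads = grads / (grads.norm(dim=-1, keepdim=True) + 1e-8)

    # Momentum update of the per-task gradient tracker y.
    y = y - (beta / step ** beta_sigma) * (y - grads)
    # Softmax (mirror-descent) update of the task weights lambd.
    lambd = F.softmax(lambd - (gamma / step ** gamma_sigma) * (y @ y.t() + rho * torch.eye(task_num)) @ lambd, -1)

    # The gradient written back to the model would be y.t() @ lambd.
    print(step, lambd.numpy())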