Source code for LibMTL.weighting.GradVac

import torch, random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting


class GradVac(AbsWeighting):
    r"""Gradient Vaccine (GradVac).

    This method is proposed in `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) <https://openreview.net/forum?id=F1vEjWK-lH_>`_ \
    and implemented by us.

    Args:
        GradVac_beta (float, default=0.5): The exponential moving average (EMA) decay parameter.
        GradVac_group_type (int, default=0): The parameter granularity (0: whole_model; 1: all_layer; 2: all_matrix).

    .. warning::
            GradVac is not supported with representation gradients, i.e., ``rep_grad`` must be ``False``.

    """
    def __init__(self):
        super(GradVac, self).__init__()
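    # Usage note (added for clarity, not part of the original source): the two
    # hyperparameters documented above are read inside ``backward`` from its
    # keyword arguments, e.g. ``backward(losses, GradVac_beta=0.5,
    # GradVac_group_type=0)``. When training with LibMTL's Trainer they are
    # typically supplied through the weighting-argument dict (``weight_args``)
    # rather than by calling ``backward`` directly.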
    def init_param(self):
        self.step = 0
    def _init_rho(self, group_type):
        # k_idx stores the flattened parameter count of each gradient group
        if group_type == 0: # whole_model
            self.k_idx = [-1]
        elif group_type == 1: # all_layer
            self.k_idx = []
            for module in self.encoder.modules():
                # leaf modules that own parameters
                if len(module._modules.items()) == 0 and len(module._parameters) > 0:
                    self.k_idx.append(sum([w.data.numel() for w in module.parameters()]))
        elif group_type == 2: # all_matrix
            self._compute_grad_dim()
            self.k_idx = self.grad_index
        else:
            raise ValueError('Unsupported GradVac_group_type: {}'.format(group_type))
        # EMA of the pairwise gradient similarity for each task pair and parameter group
        self.rho_T = torch.zeros(self.task_num, self.task_num, len(self.k_idx)).to(self.device)
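    # Worked illustration (hypothetical encoder, added for clarity): for an
    # encoder with two leaf layers holding 12 and 8 parameters,
    #   group_type=0 -> k_idx = [-1]             (a single group spanning the whole model)
    #   group_type=1 -> k_idx = [12, 8]          (one group per layer)
    #   group_type=2 -> k_idx = self.grad_index  (one group per parameter tensor)
    # ``backward`` then slices a flattened gradient with
    # ``sum(k_idx[:k]) : sum(k_idx[:k+1])`` to recover group ``k``.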
    def backward(self, losses, **kwargs):
        beta = kwargs['GradVac_beta']
        group_type = kwargs['GradVac_group_type']
        if self.step == 0:
            self._init_rho(group_type)

        if self.rep_grad:
            raise ValueError('GradVac does not support representation gradients (rep_grad=True)')
        else:
            self._compute_grad_dim()
            grads = self._compute_grad(losses, mode='backward') # [task_num, grad_dim]

        batch_weight = np.ones(len(losses))
        pc_grads = grads.clone()
        for tn_i in range(self.task_num):
            task_index = list(range(self.task_num))
            task_index.remove(tn_i)
            random.shuffle(task_index)
            for tn_j in task_index:
                for k in range(len(self.k_idx)):
                    # slice out the k-th parameter group
                    beg, end = sum(self.k_idx[:k]), sum(self.k_idx[:k+1])
                    if end == -1:
                        end = grads.size()[-1]
                    # cosine similarity between task i's (already adjusted) gradient and task j's gradient
                    rho_ijk = torch.dot(pc_grads[tn_i, beg:end], grads[tn_j, beg:end]) / (pc_grads[tn_i, beg:end].norm()*grads[tn_j, beg:end].norm()+1e-8)
                    if rho_ijk < self.rho_T[tn_i, tn_j, k]:
                        # weight that raises the similarity up to the EMA target rho_T
                        w = pc_grads[tn_i, beg:end].norm()*(self.rho_T[tn_i, tn_j, k]*(1-rho_ijk**2).sqrt()-rho_ijk*(1-self.rho_T[tn_i, tn_j, k]**2).sqrt())/(grads[tn_j, beg:end].norm()*(1-self.rho_T[tn_i, tn_j, k]**2).sqrt()+1e-8)
                        pc_grads[tn_i, beg:end] += grads[tn_j, beg:end]*w
                        # batch_weight[tn_j] += w.item()
                    # update the EMA target with the observed similarity
                    self.rho_T[tn_i, tn_j, k] = (1-beta)*self.rho_T[tn_i, tn_j, k] + beta*rho_ijk
        new_grads = pc_grads.sum(0)
        self._reset_grad(new_grads)
        self.step += 1
        return batch_weight
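# ---------------------------------------------------------------------------
# Minimal, self-contained sketch (added for illustration, not part of LibMTL):
# it reproduces the core adjustment from ``backward`` above for a single task
# pair and a single parameter group. When the observed cosine similarity of two
# task gradients falls below the EMA target ``rho_T``, task j's gradient is
# added to task i's with a weight ``w`` chosen so that the adjusted similarity
# equals the target. All values below are made up; only ``torch`` (imported at
# the top of this module) is used.
if __name__ == '__main__':
    torch.manual_seed(0)

    g_i = torch.randn(10)                    # gradient of task i for one group
    g_j = -g_i + 0.1 * torch.randn(10)       # strongly conflicting gradient of task j
    rho_T = torch.tensor(0.3)                # EMA similarity target for this pair/group
    beta = 0.5                               # corresponds to GradVac_beta

    # observed cosine similarity (here close to -1)
    rho = torch.dot(g_i, g_j) / (g_i.norm() * g_j.norm() + 1e-8)

    if rho < rho_T:
        # same formula as in ``backward``: lift the similarity up to rho_T
        w = g_i.norm() * (rho_T * (1 - rho**2).sqrt() - rho * (1 - rho_T**2).sqrt()) \
            / (g_j.norm() * (1 - rho_T**2).sqrt() + 1e-8)
        g_i = g_i + w * g_j

    # the EMA target is always updated with the observed (pre-adjustment) similarity
    rho_T = (1 - beta) * rho_T + beta * rho

    new_rho = torch.dot(g_i, g_j) / (g_i.norm() * g_j.norm() + 1e-8)
    print('similarity after adjustment: {:.3f}'.format(new_rho.item()))  # ~0.3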