Source code for LibMTL.weighting.PCGrad

import torch, random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

class PCGrad(AbsWeighting):
    r"""Project Conflicting Gradients (PCGrad).

    This method is proposed in `Gradient Surgery for Multi-Task Learning (NeurIPS 2020) <https://papers.nips.cc/paper/2020/hash/3fe78a8acf5fda99de95303940a2420c-Abstract.html>`_ and implemented by us.

    .. warning::
            PCGrad does not support representation gradients, i.e., ``rep_grad`` must be ``False``.

    """
    def __init__(self):
        super(PCGrad, self).__init__()
    def backward(self, losses, **kwargs):
        batch_weight = np.ones(len(losses))
        if self.rep_grad:
            raise ValueError('PCGrad does not support representation gradients (rep_grad=True)')
        else:
            self._compute_grad_dim()
            grads = self._compute_grad(losses, mode='backward')  # [task_num, grad_dim]
        pc_grads = grads.clone()
        for tn_i in range(self.task_num):
            # Visit the other tasks in random order, as in the original paper.
            task_index = list(range(self.task_num))
            random.shuffle(task_index)
            for tn_j in task_index:
                g_ij = torch.dot(pc_grads[tn_i], grads[tn_j])
                if g_ij < 0:
                    # Conflict: project out the component of task i's gradient
                    # that points against task j's gradient.
                    pc_grads[tn_i] -= g_ij * grads[tn_j] / (grads[tn_j].norm().pow(2) + 1e-8)
                    batch_weight[tn_j] -= (g_ij / (grads[tn_j].norm().pow(2) + 1e-8)).item()
        new_grads = pc_grads.sum(0)
        self._reset_grad(new_grads)
        return batch_weight
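The core operation above is the projection ``g_i <- g_i - (g_i . g_j / ||g_j||^2) * g_j``, applied whenever ``g_i . g_j < 0``. The snippet below is a minimal, self-contained sketch of that single projection step on two made-up gradient vectors; it is an illustration only, not part of LibMTL, and the tensors ``g1`` and ``g2`` are chosen purely for demonstration.

    import torch

    # Two deliberately conflicting task gradients: their dot product is negative.
    g1 = torch.tensor([1.0, 0.0])
    g2 = torch.tensor([-1.0, 1.0])

    pc_g1 = g1.clone()
    g_12 = torch.dot(pc_g1, g2)  # -1.0 < 0, so the gradients conflict
    if g_12 < 0:
        # Remove the component of g1 that points against g2
        # (same epsilon-stabilized formula as in backward() above).
        pc_g1 = pc_g1 - g_12 * g2 / (g2.norm().pow(2) + 1e-8)

    print(pc_g1)                  # tensor([0.5000, 0.5000])
    print(torch.dot(pc_g1, g2))   # ~0: the conflict has been projected away

After the projection, ``pc_g1`` is orthogonal to ``g2``, so summing the surgically modified per-task gradients (as ``backward()`` does with ``pc_grads.sum(0)``) no longer lets one task's update cancel another's.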