# Source code for LibMTL.weighting.GradNorm

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

class GradNorm(AbsWeighting):
    r"""Gradient Normalization (GradNorm).

    This method is proposed in `GradNorm: Gradient Normalization for Adaptive
    Loss Balancing in Deep Multitask Networks (ICML 2018)
    <http://proceedings.mlr.press/v80/chen18a/chen18a.pdf>`_ \
    and implemented by us.

    Args:
        alpha (float, default=1.5): The strength of the restoring force which
            pulls tasks back to a common training rate.
    """

    def __init__(self):
        super(GradNorm, self).__init__()

    def init_param(self):
        """Create one learnable loss scale per task, all initialized to 1."""
        self.loss_scale = nn.Parameter(
            torch.tensor([1.0] * self.task_num, device=self.device))

    def backward(self, losses, **kwargs):
        """Re-weight task losses via GradNorm and backpropagate.

        Args:
            losses (torch.Tensor): Per-task losses (one scalar per task).
            **kwargs: Must contain ``alpha``, the restoring-force strength.

        Returns:
            numpy.ndarray: The loss weight applied to each task. All ones in
            the first epoch, before any training-rate statistics exist.
        """
        alpha = kwargs['alpha']
        if self.epoch >= 1:
            # Normalize the learnable scales so their sum equals task_num.
            loss_scale = self.task_num * F.softmax(self.loss_scale, dim=-1)
            grads = self._get_grads(losses, mode='backward')
            if self.rep_grad:
                # _get_grads returns (per-representation grads, flattened grads)
                # in rep-grad mode; keep both.
                per_grads, grads = grads[0], grads[1]

            # G_per_loss: L2 norm of each task's scaled gradient; G: their mean,
            # used as the common target gradient magnitude.
            G_per_loss = torch.norm(loss_scale.unsqueeze(1) * grads, p=2, dim=-1)
            G = G_per_loss.mean(0)

            # Inverse training rate: current loss relative to the loss recorded
            # at epoch 0 (train_loss_buffer[tn, 0] — presumably the first-epoch
            # loss; maintained by the trainer).
            L_i = torch.Tensor([losses[tn].item() / self.train_loss_buffer[tn, 0]
                                for tn in range(self.task_num)]).to(self.device)
            r_i = L_i / L_i.mean()

            # Target gradient norms are treated as constants (detached) so that
            # L_grad only trains the loss scales, not the network.
            constant_term = (G * (r_i ** alpha)).detach()
            L_grad = (G_per_loss - constant_term).abs().sum(0)
            L_grad.backward()

            loss_weight = loss_scale.detach().clone()
            if self.rep_grad:
                self._backward_new_grads(loss_weight, per_grads=per_grads)
            else:
                self._backward_new_grads(loss_weight, grads=grads)
            return loss_weight.cpu().numpy()
        else:
            # First epoch: no loss history yet, so use uniform (unit) weights.
            loss = torch.mul(losses, torch.ones_like(losses).to(self.device)).sum()
            loss.backward()
            return np.ones(self.task_num)