import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from LibMTL.weighting.abstract_weighting import AbsWeighting
class GradNorm(AbsWeighting):
    r"""Gradient Normalization (GradNorm).

    This method is proposed in `GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks (ICML 2018) <http://proceedings.mlr.press/v80/chen18a/chen18a.pdf>`_ \
    and implemented by us.

    Args:
        alpha (float, default=1.5): The strength of the restoring force which pulls tasks back to a common training rate.
    """
    def __init__(self):
        super(GradNorm, self).__init__()

    def init_param(self):
        # Learnable per-task scaling logits (one per task, initialized to 1.0);
        # they are softmax-normalized inside backward() to produce the weights.
        self.loss_scale = nn.Parameter(torch.tensor([1.0] * self.task_num, device=self.device))

    def backward(self, losses, **kwargs):
        """Compute task weights via GradNorm and backpropagate the weighted gradients.

        Args:
            losses (torch.Tensor): 1-D tensor of the current per-task losses.
            **kwargs: Must contain ``alpha`` (float), the GradNorm restoring-force strength.

        Returns:
            numpy.ndarray: The per-task loss weights used for this step.
        """
        alpha = kwargs['alpha']
        if self.epoch >= 1:
            # Softmax keeps the weights positive; scaling by task_num keeps their sum
            # equal to the number of tasks (so uniform weighting corresponds to all ones).
            loss_scale = self.task_num * F.softmax(self.loss_scale, dim=-1)
            grads = self._get_grads(losses, mode='backward')
            if self.rep_grad:
                per_grads, grads = grads[0], grads[1]
            # G_per_loss[i]: L2 norm of the weighted gradient of task i
            # (assumes grads is (task_num, grad_dim) — produced by _get_grads).
            G_per_loss = torch.norm(loss_scale.unsqueeze(1) * grads, p=2, dim=-1)
            # G: average gradient norm across tasks (the common target magnitude).
            G = G_per_loss.mean(0)
            # L_i: loss ratio relative to the first recorded epoch; r_i is the
            # relative inverse training rate of each task.
            L_i = torch.tensor([losses[tn].item() / self.train_loss_buffer[tn, 0]
                                for tn in range(self.task_num)], device=self.device)
            r_i = L_i / L_i.mean()
            # The target G * r_i^alpha is treated as a constant (detached) so that
            # L_grad only drives the loss_scale logits, per the GradNorm paper.
            constant_term = (G * (r_i ** alpha)).detach()
            L_grad = (G_per_loss - constant_term).abs().sum(0)
            L_grad.backward()
            loss_weight = loss_scale.detach().clone()
            if self.rep_grad:
                self._backward_new_grads(loss_weight, per_grads=per_grads)
            else:
                self._backward_new_grads(loss_weight, grads=grads)
            return loss_weight.cpu().numpy()
        else:
            # Warm-up epoch: no loss history exists yet for the ratio L_i, so use
            # uniform weights. (Original code multiplied by a ones tensor before
            # summing — value and gradients are identical to a plain sum.)
            losses.sum().backward()
            return np.ones(self.task_num)