import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from LibMTL.weighting.abstract_weighting import AbsWeighting
class IMTL(AbsWeighting):
    r"""Impartial Multi-task Learning (IMTL).

    This method is proposed in `Towards Impartial Multi-task Learning (ICLR 2021)
    <https://openreview.net/forum?id=IMPnRXEWpvr>`_ and implemented by us.
    """

    def __init__(self):
        super(IMTL, self).__init__()

    def init_param(self):
        """Create one learnable loss-scale scalar per task (IMTL-L).

        The scales start at 0 so that ``exp(scale) == 1``, i.e. the initial
        transformed losses equal the raw losses.
        """
        self.loss_scale = nn.Parameter(torch.tensor([0.0]*self.task_num, device=self.device))

    def backward(self, losses, **kwargs):
        """Balance task gradients with IMTL and backpropagate the combined update.

        Args:
            losses (torch.Tensor): per-task losses; assumed 1-D of length
                ``self.task_num`` — TODO confirm against the caller.
            **kwargs: unused; kept for a uniform weighting-method interface.

        Returns:
            numpy.ndarray: the task weights ``alpha`` actually applied.
        """
        # IMTL-L: reparameterize each loss as exp(s_t)*L_t - s_t so the
        # scales s_t are learned jointly with the model parameters.
        losses = self.loss_scale.exp()*losses - self.loss_scale
        grads = self._get_grads(losses, mode='backward')
        if self.rep_grad:
            # rep_grad mode: _get_grads returns (per-representation grads,
            # flattened grads); the flattened ones drive the alpha solve.
            per_grads, grads = grads[0], grads[1]

        # Unit-norm task gradients; used so alpha depends only on directions.
        grads_unit = grads/torch.norm(grads, p=2, dim=-1, keepdim=True)

        # D[t] = g_1 - g_{t+1}, U[t] = u_1 - u_{t+1} for t = 1..T-1
        # (differences of task 1's gradient against every other task's).
        D = grads[0:1].repeat(self.task_num-1, 1) - grads[1:]
        U = grads_unit[0:1].repeat(self.task_num-1, 1) - grads_unit[1:]
        # Closed-form IMTL-G solution: alpha_{2..T} = g_1 U^T (D U^T)^{-1}.
        # NOTE(review): D U^T must be invertible — fails if task_num < 2 or
        # task gradients are linearly dependent.
        alpha = torch.matmul(torch.matmul(grads[0], U.t()), torch.inverse(torch.matmul(D, U.t())))
        # Prepend alpha_1 = 1 - sum(alpha_{2..T}) so the weights sum to 1.
        alpha = torch.cat((1-alpha.sum().unsqueeze(0), alpha), dim=0)
        if self.rep_grad:
            self._backward_new_grads(alpha, per_grads=per_grads)
        else:
            self._backward_new_grads(alpha, grads=grads)
        return alpha.detach().cpu().numpy()