import torch, random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from LibMTL.weighting.abstract_weighting import AbsWeighting
[docs]class RLW(AbsWeighting):
r"""Random Loss Weighting (RLW).
This method is proposed in `Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning (TMLR 2022) <https://openreview.net/forum?id=jjtFD8A1Wx>`_ \
and implemented by us.
"""
def __init__(self):
super(RLW, self).__init__()
[docs] def backward(self, losses, **kwargs):
batch_weight = F.softmax(torch.randn(self.task_num), dim=-1).to(self.device)
loss = torch.mul(losses, batch_weight).sum()
loss.backward()
return batch_weight.detach().cpu().numpy()