Source code for LibMTL.weighting.DWA

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

class DWA(AbsWeighting):
    r"""Dynamic Weight Average (DWA).

    This method is proposed in `End-To-End Multi-Task Learning With Attention (CVPR 2019) <https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_End-To-End_Multi-Task_Learning_With_Attention_CVPR_2019_paper.pdf>`_ \
    and implemented by modifying from the `official PyTorch implementation <https://github.com/lorenmt/mtan>`_.

    Args:
        T (float, default=2.0): The softmax temperature.

    """
    def __init__(self):
        super(DWA, self).__init__()
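
    # Weighting rule (restated here from the paper; not part of the original
    # source): with K tasks and temperature T, the weight of task k at epoch t is
    #     lambda_k(t) = K * exp(w_k(t-1) / T) / sum_i exp(w_i(t-1) / T),
    #     w_k(t-1)    = L_k(t-1) / L_k(t-2),
    # where L_k(t) is task k's average training loss in epoch t. backward()
    # below implements this rule, reading the loss history from
    # self.train_loss_buffer.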
    def backward(self, losses, **kwargs):
        T = kwargs['T']
        if self.epoch > 1:
            # Per-task loss ratio over the last two epochs, i.e. w_k(t-1).
            w_i = torch.Tensor(self.train_loss_buffer[:, self.epoch-1] / self.train_loss_buffer[:, self.epoch-2]).to(self.device)
            # Temperature-scaled softmax, rescaled so the weights sum to task_num.
            batch_weight = self.task_num * F.softmax(w_i / T, dim=-1)
        else:
            # No two-epoch loss history yet: weight all tasks equally.
            batch_weight = torch.ones_like(losses).to(self.device)
        loss = torch.mul(losses, batch_weight).sum()
        loss.backward()
        return batch_weight.detach().cpu().numpy()
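
A minimal standalone sketch of the same weighting rule, useful for checking its behavior outside LibMTL. The loss values are made up for illustration, and the two tensors mimic train_loss_buffer[:, epoch-1] and train_loss_buffer[:, epoch-2]; only torch is required.

import torch
import torch.nn.functional as F

T = 2.0        # softmax temperature (DWA's default)
task_num = 2   # number of tasks, K

# Hypothetical average training losses per task for the last two epochs.
loss_prev  = torch.tensor([0.80, 0.50])   # epoch t-1
loss_prev2 = torch.tensor([1.00, 0.55])   # epoch t-2

w_i = loss_prev / loss_prev2                          # descent rate w_k(t-1)
batch_weight = task_num * F.softmax(w_i / T, dim=-1)  # DWA weights
print(batch_weight)  # approx. tensor([0.9727, 1.0273])

Task 1's loss is decreasing more slowly (ratio 0.91 vs. 0.80), so it receives the larger weight. Raising T pushes the weights toward uniform, while a smaller T sharpens the difference between tasks.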