# Source code for LibMTL.weighting.DB_MTL

import torch, random, copy
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

class DB_MTL(AbsWeighting):
    """DB-MTL loss weighting strategy.

    On every call to :meth:`backward`, the per-task gradients of the
    log-transformed losses (``log(loss + 1e-8)``) are blended into a running
    buffer (``self.grad_buffer``). Each buffered task gradient is then
    rescaled by ``max_norm / (own_norm + 1e-8)``, so that all of them have
    (approximately) the norm of the largest one. The rescaled gradients are
    summed and handed to ``self._reset_grad``.

    Only the parameter-gradient path is supported. Representation gradients
    (``self.rep_grad`` truthy) raise ``ValueError``.
    """

    def __init__(self):
        super().__init__()

    def init_param(self):
        """Reset the step counter and allocate a zeroed gradient buffer.

        The buffer has shape ``(self.task_num, self.grad_dim)`` and lives on
        ``self.device``. ``self._compute_grad_dim()`` is called first,
        presumably to set ``self.grad_dim`` (it is inherited from
        ``AbsWeighting`` -- TODO confirm).
        """
        self.step = 0
        self._compute_grad_dim()
        self.grad_buffer = torch.zeros(self.task_num, self.grad_dim, device=self.device)

    def backward(self, losses, **kwargs):
        """Combine the task gradients and install the result.

        Args:
            losses: Per-task losses. They are passed through ``torch.log`` and
                measured with ``len()``, so presumably a 1-D tensor with one
                entry per task -- verify against caller.
            **kwargs: Must contain ``'DB_beta'`` and ``'DB_beta_sigma'``.
                A missing key raises ``KeyError``. The buffer's blending
                coefficient at step ``t`` is ``DB_beta / t ** DB_beta_sigma``.

        Returns:
            A NumPy array of ones with length ``len(losses)``. The weighting
            is applied to the gradients directly, not through these values.

        Raises:
            ValueError: if ``self.rep_grad`` is truthy.
        """
        # The counter advances before any validation, matching the original
        # statement order.
        self.step += 1
        beta = kwargs['DB_beta']
        sigma = kwargs['DB_beta_sigma']
        weights = np.ones(len(losses))

        if self.rep_grad:
            raise ValueError('No support method DB_MTL with representation gradients (rep_grad=True)')

        self._compute_grad_dim()
        # Gradients of the log-losses; per the original code: [task_num, grad_dim].
        fresh = self._compute_grad(torch.log(losses+1e-8), mode='backward')

        # Blend the buffer toward the new gradients. With momentum m, the new
        # buffer equals (1 - m) * fresh + m * old_buffer.
        momentum = beta / self.step ** sigma
        self.grad_buffer = fresh + momentum * (self.grad_buffer - fresh)

        # Scale every task so its buffered gradient matches the largest norm.
        # The 1e-8 guards against division by a zero norm.
        norms = self.grad_buffer.norm(dim=-1)
        scales = norms.max() / (norms + 1e-8)

        # Accumulate left-to-right from 0, exactly as the built-in sum() would.
        combined = 0
        for task in range(self.task_num):
            combined = combined + scales[task] * self.grad_buffer[task]

        self._reset_grad(combined)
        return weights