import torch, sys, random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class AbsWeighting(nn.Module):
    r"""An abstract class for weighting strategies.

    This class only provides gradient bookkeeping helpers. It reads several
    attributes and methods that are **not** defined here and must be supplied
    elsewhere (e.g. by a subclass or a co-operating model class):
    ``get_share_params()``, ``zero_grad_share_params()``, ``task_num``,
    ``task_name``, ``device``, ``rep_grad``, ``rep`` and ``rep_tasks``.
    """

    def __init__(self):
        super().__init__()

    def init_param(self):
        r"""Define and initialize some trainable parameters required by specific weighting methods.
        """
        pass

    def _compute_grad_dim(self):
        r"""Record the number of elements of every shared parameter.

        Sets ``self.grad_index`` (a list with one element count per parameter
        yielded by ``get_share_params()``) and ``self.grad_dim`` (their sum).
        """
        self.grad_index = [param.data.numel() for param in self.get_share_params()]
        self.grad_dim = sum(self.grad_index)

    def _grad2vec(self):
        r"""Flatten the current ``.grad`` of all shared parameters into one vector.

        Parameters whose ``.grad`` is ``None`` keep their slot filled with zeros.

        Returns:
            torch.Tensor: A vector with ``self.grad_dim`` elements.
        """
        grad = torch.zeros(self.grad_dim)
        offset = 0
        for idx, param in enumerate(self.get_share_params()):
            numel = self.grad_index[idx]
            if param.grad is not None:
                grad[offset:offset + numel] = param.grad.data.view(-1)
            # Advance even when the gradient is missing so that later
            # parameters stay aligned with ``self.grad_index``.
            offset += numel
        return grad

    def _compute_grad(self, losses, mode, rep_grad=False):
        r"""Compute the per-task gradients.

        Args:
            losses (list): A list of losses of each task.
            mode (str): ``'backward'`` or ``'autograd'``.
            rep_grad (bool): If ``False``, gradients are taken with respect to the \
                shared parameters; if ``True``, they are read from ``self.rep_tasks``.

        Returns:
            If ``rep_grad`` is ``False``, a tensor of size [task_num, grad_dim]. \
            Otherwise a tensor of size [task_num, *rep.size()] when ``self.rep`` \
            is not a dict, or a list with one tensor per task when it is a dict.

        Raises:
            ValueError: If ``rep_grad`` is ``False`` and ``mode`` is not supported.
        """
        if not rep_grad:
            grads = torch.zeros(self.task_num, self.grad_dim).to(self.device)
            for tn in range(self.task_num):
                if mode == 'backward':
                    # Keep the graph alive for every task except the last one.
                    losses[tn].backward(retain_graph=(tn + 1) != self.task_num)
                    grads[tn] = self._grad2vec()
                elif mode == 'autograd':
                    grad = list(torch.autograd.grad(losses[tn], self.get_share_params(), retain_graph=True))
                    grads[tn] = torch.cat([g.view(-1) for g in grad])
                else:
                    raise ValueError('No support {} mode for gradient computation'.format(mode))
                # Clear the accumulated gradients before handling the next task.
                self.zero_grad_share_params()
        else:
            if not isinstance(self.rep, dict):
                grads = torch.zeros(self.task_num, *self.rep.size()).to(self.device)
            else:
                grads = [torch.zeros(*self.rep[task].size()) for task in self.task_name]
            for tn, task in enumerate(self.task_name):
                # NOTE(review): only the 'backward' mode is handled here; any other
                # mode leaves ``grads`` filled with zeros without raising.
                if mode == 'backward':
                    losses[tn].backward(retain_graph=(tn + 1) != self.task_num)
                    grads[tn] = self.rep_tasks[task].grad.data.clone()
        return grads

    def _reset_grad(self, new_grads):
        r"""Overwrite the ``.grad`` of every shared parameter from a flat vector.

        Parameters whose ``.grad`` is ``None`` are skipped, but their slot in
        ``new_grads`` is still consumed.

        Args:
            new_grads (torch.Tensor): A vector laid out according to ``self.grad_index``.
        """
        offset = 0
        for idx, param in enumerate(self.get_share_params()):
            numel = self.grad_index[idx]
            if param.grad is not None:
                chunk = new_grads[offset:offset + numel]
                param.grad.data = chunk.contiguous().view(param.data.size()).data.clone()
            offset += numel

    def _get_grads(self, losses, mode='backward'):
        r"""This function is used to return the gradients of representations or shared parameters.

        If ``rep_grad`` is ``True``, it returns a list with two elements. The first element is \
        the gradients of the representations with the size of [task_num, batch_size, rep_size]. \
        The second element is the resized gradients with size of [task_num, -1], which means \
        the gradient of each task is resized as a vector.

        If ``rep_grad`` is ``False``, it returns the gradients of the shared parameters with size \
        of [task_num, -1], which means the gradient of each task is resized as a vector.

        Args:
            losses (list): A list of losses of each task.
            mode (str): Forwarded to ``_compute_grad``.

        Raises:
            ValueError: If ``self.rep`` is a dict and the per-task gradients cannot \
                be stacked and reshaped.
        """
        if self.rep_grad:
            per_grads = self._compute_grad(losses, mode, rep_grad=True)
            if not isinstance(self.rep, dict):
                grads = per_grads.reshape(self.task_num, self.rep.size()[0], -1).sum(1)
            else:
                try:
                    grads = torch.stack(per_grads).sum(1).view(self.task_num, -1)
                except Exception as err:
                    raise ValueError('The representation dimensions of different tasks must be consistent') from err
            return [per_grads, grads]
        else:
            self._compute_grad_dim()
            grads = self._compute_grad(losses, mode)
            return grads

    def _backward_new_grads(self, batch_weight, per_grads=None, grads=None):
        r"""This function is used to reset the gradients and make a backward.

        Args:
            batch_weight (torch.Tensor): A tensor with size of [task_num].
            per_grads (torch.Tensor): It is needed if ``rep_grad`` is True. The gradients of the representations.
            grads (torch.Tensor): It is needed if ``rep_grad`` is False. The gradients of the shared parameters.
        """
        if self.rep_grad:
            if not isinstance(self.rep, dict):
                transformed_grad = sum(batch_weight[i] * per_grads[i] for i in range(self.task_num))
                self.rep.backward(transformed_grad)
            else:
                for tn, task in enumerate(self.task_name):
                    # Keep the graph alive for every task except the last one.
                    retain = (tn + 1) != self.task_num
                    self.rep[task].backward(batch_weight[tn] * per_grads[tn], retain_graph=retain)
        else:
            new_grads = sum(batch_weight[i] * grads[i] for i in range(self.task_num))
            self._reset_grad(new_grads)

    def backward(self, losses, **kwargs):
        r"""
        Args:
            losses (list): A list of losses of each task.
            kwargs (dict): A dictionary of hyperparameters of weighting methods.
        """
        pass