import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.architecture.abstract_arch import AbsArchitecture

class _transform_resnet_ltb(nn.Module):
    """Rewire per-task ResNet encoders into a Learning-To-Branch topology.

    Each task owns a stem (``conv1``, ``bn1``, ``relu``, ``maxpool``) and four
    residual stages (``layer1`` .. ``layer4``) taken from its encoder. The
    input of every stage is a weighted sum of the outputs that *all* tasks
    produced at the previous level; the weights come from ``self.alpha``.

    Args:
        encoder_list: Indexable collection with one encoder per task. Each
            encoder must expose ``conv1``, ``bn1``, ``relu``, ``maxpool`` and
            ``layer1`` .. ``layer4`` attributes.
        task_name (list): Task names; their order defines the task index.
        device: Device on which the warmup weights are created.
    """

    def __init__(self, encoder_list, task_name, device):
        super(_transform_resnet_ltb, self).__init__()
        self.task_name = task_name
        self.task_num = len(task_name)
        self.device = device

        # One stem per task, keyed by task name.
        self.resnet_conv = nn.ModuleDict({
            task: nn.Sequential(encoder_list[tn].conv1, encoder_list[tn].bn1,
                                encoder_list[tn].relu, encoder_list[tn].maxpool)
            for tn, task in enumerate(self.task_name)
        })

        # resnet_layer[str(i)][tn] is ``layer{i+1}`` of the tn-th encoder.
        self.resnet_layer = nn.ModuleDict({})
        for i in range(4):
            self.resnet_layer[str(i)] = nn.ModuleList([])
            for tn in range(self.task_num):
                encoder = encoder_list[tn]
                self.resnet_layer[str(i)].append(getattr(encoder, f'layer{i + 1}'))

        # Branching logits indexed as alpha[level, task, source module].
        # forward() only reads levels 1..4; the first dimension is kept at 6
        # so the parameter shape stays unchanged.
        self.alpha = nn.Parameter(torch.ones(6, self.task_num, self.task_num))

    def forward(self, inputs, epoch, epochs):
        """Compute the per-task representations.

        Args:
            inputs (torch.Tensor): The input data, fed to every task's stem.
            epoch (int): Index of the current epoch.
            epochs (int): Total number of epochs.

        Returns:
            list: One representation per task (output of the last stage),
            ordered like ``task_name``.
        """
        if epoch < epochs / 100:  # warmup: every connection has weight 1
            alpha = torch.ones(6, self.task_num, self.task_num, device=self.device)
        else:
            tau = epochs / 20 / np.sqrt(epoch + 1)  # tau decay
            # hard=True: one-hot selection over the source modules (last dim).
            alpha = F.gumbel_softmax(self.alpha, dim=-1, tau=tau, hard=True)

        # Level 0: each task's stem applied to the shared input.
        prev = [self.resnet_conv[task](inputs) for task in self.task_name]

        # Levels 1..4: mix the previous level's outputs, then apply the stage.
        for i in range(1, 5):  # i: layer idx
            stage = self.resnet_layer[str(i - 1)]
            cur = []
            for tn in range(self.task_num):  # tn: task idx
                child_rep = sum(alpha[i, tn, j] * prev[j]
                                for j in range(self.task_num))  # j: module idx
                cur.append(stage[tn](child_rep))
            prev = cur
        return prev

class LTB(AbsArchitecture):
    r"""Learning To Branch (LTB).

    This method is proposed in *Learning to Branch for Multi-Task Learning*
    (ICML 2020) and implemented by us.

    .. warning::
            - :class:`LTB` does not work with multi-input problems, i.e., ``multi_input`` must be ``False``.
            - :class:`LTB` is only supported by ResNet-based encoders.

    Raises:
        ValueError: If ``multi_input`` is ``True``.
    """

    def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs):
        super(LTB, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)

        if self.multi_input:
            raise ValueError('No support LTB for multiple inputs MTL problem')

        # Build one encoder per task, then wrap them in the branching module.
        self.encoder = nn.ModuleList([self.encoder_class() for _ in range(self.task_num)])
        self.encoder = _transform_resnet_ltb(self.encoder, task_name, device)

    def forward(self, inputs, task_name=None):
        r"""

        Args:
            inputs (torch.Tensor): The input data.
            task_name (str, default=None): The task name corresponding to ``inputs`` if ``multi_input`` is ``True``.

        Returns:
            dict: A dictionary of name-prediction pairs of type (:class:`str`, :class:`torch.Tensor`).
        """
        out = {}
        # NOTE(review): ``self.epoch`` and ``self.epochs`` are read here but not
        # assigned in this class -- presumably set by the training loop; confirm.
        s_rep = self.encoder(inputs, self.epoch, self.epochs)
        same_rep = not isinstance(s_rep, list) and not self.multi_input
        for tn, task in enumerate(self.task_name):
            if task_name is not None and task != task_name:
                continue
            ss_rep = s_rep[tn] if isinstance(s_rep, list) else s_rep
            ss_rep = self._prepare_rep(ss_rep, task, same_rep)
            out[task] = self.decoders[task](ss_rep)
        return out