Source code for LibMTL.architecture.MMoE

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.architecture.abstract_arch import AbsArchitecture

class MMoE(AbsArchitecture):
    r"""Multi-gate Mixture-of-Experts (MMoE).

    This method is proposed in `Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts (KDD 2018) <https://dl.acm.org/doi/10.1145/3219819.3220007>`_ \
    and implemented by us.

    Args:
        img_size (list): The size of the input data. For example, [3, 224, 224] denotes input images of size 3x224x224.
        num_experts (int): The number of experts shared by all tasks. Each expert is an encoder network.
    """
    def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs):
        super(MMoE, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)
        self.img_size = self.kwargs['img_size']
        self.input_size = np.array(self.img_size, dtype=int).prod()
        self.num_experts = self.kwargs['num_experts'][0]
        self.experts_shared = nn.ModuleList([encoder_class() for _ in range(self.num_experts)])
        self.gate_specific = nn.ModuleDict({task: nn.Sequential(nn.Linear(self.input_size, self.num_experts),
                                                                nn.Softmax(dim=-1)) for task in self.task_name})
    def forward(self, inputs, task_name=None):
        # Run every shared expert on the input: [num_experts, batch, ...]
        experts_shared_rep = torch.stack([e(inputs) for e in self.experts_shared])
        out = {}
        for task in self.task_name:
            if task_name is not None and task != task_name:
                continue
            # Task-specific gate: softmax weights over the experts, shape [batch, num_experts]
            selector = self.gate_specific[task](torch.flatten(inputs, start_dim=1))
            # Mix expert outputs per sample with the task's gate weights (see the sketch after the class)
            gate_rep = torch.einsum('ij..., ji -> j...', experts_shared_rep, selector)
            gate_rep = self._prepare_rep(gate_rep, task, same_rep=False)
            out[task] = self.decoders[task](gate_rep)
        return out
    def get_share_params(self):
        return self.experts_shared.parameters()
    def zero_grad_share_params(self):
        self.experts_shared.zero_grad()
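
A minimal standalone sketch (not part of the library) of what the ``torch.einsum('ij..., ji -> j...')`` call in ``forward`` computes: for each sample, the task-specific gate weights are used to take a weighted sum of the stacked expert representations. The tensor sizes below are arbitrary placeholders.

# Sketch: the einsum in MMoE.forward is a per-sample weighted sum over experts.
import torch

num_experts, batch, feat = 3, 4, 8                                  # hypothetical sizes
experts_rep = torch.randn(num_experts, batch, feat)                  # stacked expert outputs
selector = torch.softmax(torch.randn(batch, num_experts), dim=-1)    # gate weights per sample

# einsum form used in the source above
gate_rep = torch.einsum('ij..., ji -> j...', experts_rep, selector)

# explicit loop form: mix the experts for each sample with its own gate weights
manual = torch.stack([
    sum(selector[b, i] * experts_rep[i, b] for i in range(num_experts))
    for b in range(batch)
])
assert torch.allclose(gate_rep, manual, atol=1e-6)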
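
A minimal usage sketch, not taken from LibMTL's documentation. It assumes a toy MLP encoder/decoder, a single placeholder task named ``'A'``, and that the architecture-specific arguments are passed as a dict under the ``kwargs`` keyword, mirroring how LibMTL's ``Trainer`` forwards its architecture arguments; that argument handling lives in ``AbsArchitecture``, which is not shown on this page, so treat it as an assumption.

# Usage sketch (assumptions: toy encoder/decoder, task 'A', kwargs passed as a dict).
import torch
import torch.nn as nn
from LibMTL.architecture.MMoE import MMoE

img_size = [3, 32, 32]                                   # input_size = 3 * 32 * 32
encoder_class = lambda: nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU())
decoders = nn.ModuleDict({'A': nn.Linear(64, 1)})        # one toy decoder per task

model = MMoE(task_name=['A'], encoder_class=encoder_class, decoders=decoders,
             rep_grad=False, multi_input=False, device='cpu',
             kwargs={'img_size': img_size, 'num_experts': [2]})

x = torch.randn(4, 3, 32, 32)
out = model(x)                                           # {'A': tensor of shape [4, 1]}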