# Source code for LibMTL.architecture.DSelect_k

import torch, math
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.architecture.MMoE import MMoE

class DSelect_k(MMoE):
    r"""DSelect-k.

    This method is proposed in *DSelect-k: Differentiable Selection in the
    Mixture of Experts with Applications to Multi-Task Learning* (NeurIPS 2021)
    and implemented by modifying the official TensorFlow implementation.

    Args:
        img_size (list): The size of input data. For example, [3, 224, 224]
            denotes input images with size 3x224x224.
        num_experts (int): The number of experts shared by all the tasks.
            Each expert is an encoder network.
        num_nonzeros (int): The number of selected experts.
        kgamma (float, default=1.0): A scaling parameter for the smooth-step
            function. It also sets the range of the uniform initialization of
            the binary-code logits, ``U(-kgamma/100, kgamma/100)``.
    """
    def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs):
        super().__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)

        self._num_nonzeros = self.kwargs['num_nonzeros']
        # Honour the documented default instead of raising KeyError when omitted.
        self._gamma = self.kwargs.get('kgamma', 1.0)

        # Each of the num_nonzeros selectors addresses an expert through a soft
        # binary code of ceil(log2(num_experts)) bits.
        self._num_binary = math.ceil(math.log2(self.num_experts))
        # If num_experts is not a power of 2, some codes address no expert;
        # _entropy_reg_loss adds an extra penalty in that case.
        self._power_of_2 = (self.num_experts == 2 ** self._num_binary)

        self._z_logits = nn.ModuleDict({task: nn.Linear(self.input_size, self._num_nonzeros * self._num_binary)
                                        for task in self.task_name})
        self._w_logits = nn.ModuleDict({task: nn.Linear(self.input_size, self._num_nonzeros)
                                        for task in self.task_name})

        # initialization
        for param in self._z_logits.parameters():
            nn.init.uniform_(param, -self._gamma / 100, self._gamma / 100)
        for param in self._w_logits.parameters():
            nn.init.uniform_(param, -0.05, 0.05)

        # Row e is the most-significant-bit-first binary code of expert e.
        # Bits are extracted with integer shifts rather than by casting '0'/'1'
        # strings to bool, so the result does not depend on NumPy's
        # string->bool casting rules.
        binary_matrix = np.array(
            [[(val >> shift) & 1 for shift in reversed(range(self._num_binary))]
             for val in range(self.num_experts)],
            dtype=bool,
        )
        # Shape (1, num_experts, num_binary).
        self._binary_codes = torch.from_numpy(binary_matrix).to(self.device).unsqueeze(0)

        # DSelect-k replaces MMoE's task-specific softmax gates.
        self.gate_specific = None

    def _smooth_step_fun(self, t, gamma=1.0):
        """Cubic smooth-step applied elementwise to ``t``.

        Returns 0 for ``t <= -gamma/2`` and 1 for ``t >= gamma/2``. In between it
        returns ``-2/gamma**3 * t**3 + 3/(2*gamma) * t + 1/2``, which equals 0
        and 1 at the two thresholds, so the function is continuous.
        """
        return torch.where(t <= -gamma / 2, torch.zeros_like(t),
                           torch.where(t >= gamma / 2, torch.ones_like(t),
                                       (-2 / (gamma ** 3)) * (t ** 3) + (3 / (2 * gamma)) * t + 1 / 2))

    def _entropy_reg_loss(self, inputs):
        """Backpropagate the selector regularizer immediately.

        Args:
            inputs: Selector outputs; in ``forward`` this tensor has shape
                (batch, num_nonzeros, num_experts).

        The loss is a small (1e-6-weighted) entropy term. When ``num_experts``
        is not a power of 2, it also includes the sum of reciprocals of each
        selector's total mass over the real experts. That total is at most 1,
        and the penalty is smallest when it reaches 1.
        """
        loss = -(inputs * torch.log(inputs + 1e-6)).sum() * 1e-6
        if not self._power_of_2:
            loss += (1 / inputs.sum(-1)).sum()
        # Gradients accumulate now, before the caller's main backward pass;
        # retain_graph keeps the shared graph alive for that later pass.
        loss.backward(retain_graph=True)

    def forward(self, inputs, task_name=None):
        """Compute per-task outputs.

        Args:
            inputs: Input batch fed to every shared expert.
            task_name (str, optional): If given, only this task is computed.

        Returns:
            dict: Maps each computed task name to its decoder output.
        """
        experts_shared_rep = torch.stack([e(inputs) for e in self.experts_shared])
        # Loop-invariant: the gating networks consume the flattened input.
        flat_inputs = torch.flatten(inputs, start_dim=1)
        out = {}
        for task in self.task_name:
            if task_name is not None and task != task_name:
                continue
            sample_logits = self._z_logits[task](flat_inputs)
            sample_logits = sample_logits.reshape(-1, self._num_nonzeros, 1, self._num_binary)
            smooth_step_activations = self._smooth_step_fun(sample_logits, gamma=self._gamma)
            # (batch, num_nonzeros, num_experts): soft one-hot over experts per selector.
            selector_outputs = torch.where(self._binary_codes.unsqueeze(0), smooth_step_activations,
                                           1 - smooth_step_activations).prod(3)
            selector_weights = F.softmax(self._w_logits[task](flat_inputs), dim=1)
            expert_weights = torch.einsum('ij, ij... -> i...', selector_weights, selector_outputs)
            gate_rep = torch.einsum('ij, ji... -> i...', expert_weights, experts_shared_rep)
            gate_rep = self._prepare_rep(gate_rep, task, same_rep=False)
            out[task] = self.decoders[task](gate_rep)
            # The regularizer calls backward(), so only run it when training
            # with autograd enabled (it would fail under torch.no_grad()).
            if self.training and torch.is_grad_enabled():
                self._entropy_reg_loss(selector_outputs)
        return out