LibMTL.architecture.DSelect_k

class DSelect_k(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)

Bases: LibMTL.architecture.MMoE.MMoE

DSelect-k.

This method was proposed in DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning (NeurIPS 2021). The implementation is adapted from the official TensorFlow implementation.

Parameters
  • img_size (list) – The size of input data. For example, [3, 224, 224] denotes input images with size 3x224x224.

  • num_experts (int) – The number of experts shared by all the tasks. Each expert is an encoder network.

  • num_nonzeros (int) – The number of selected experts, i.e. the maximum number of experts that receive a nonzero weight (the k in DSelect-k).

  • kgamma (float, default=1.0) – A scaling parameter for the smooth-step function.
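
To make the roles of num_experts, num_nonzeros and kgamma more concrete, the following is a minimal pure-Python sketch of the gating idea as described in the DSelect-k paper. It is not LibMTL's code. The function names (smooth_step, single_expert_selector, dselect_k_gate) are hypothetical, the sketch assumes num_experts is a power of two, and it assumes kgamma plays the role of the smooth-step width gamma and num_nonzeros the number of single-expert selectors.

```python
import math


def smooth_step(t, gamma=1.0):
    """Cubic smooth-step: exactly 0 for t <= -gamma/2, exactly 1 for t >= gamma/2."""
    if t <= -gamma / 2:
        return 0.0
    if t >= gamma / 2:
        return 1.0
    return -2.0 / gamma**3 * t**3 + 3.0 / (2.0 * gamma) * t + 0.5


def single_expert_selector(z, gamma=1.0):
    """Turn log2(num_experts) logits into weights over experts via binary encoding.

    Expert i gets the product over bits j of S(z_j) if bit j of i is 1,
    else 1 - S(z_j). The weights always sum to 1, and become one-hot
    once every S(z_j) has saturated to 0 or 1.
    """
    s = [smooth_step(t, gamma) for t in z]
    weights = []
    for i in range(2 ** len(z)):
        w = 1.0
        for j, s_j in enumerate(s):
            w *= s_j if (i >> j) & 1 else 1.0 - s_j
        weights.append(w)
    return weights


def dselect_k_gate(z_list, alpha, gamma=1.0):
    """Combine k single-expert selectors with softmax(alpha) mixing weights.

    len(z_list) corresponds to num_nonzeros (assumption): at most that many
    experts can end up with a nonzero weight once the selectors are one-hot.
    """
    shift = max(alpha)
    exps = [math.exp(a - shift) for a in alpha]
    total = sum(exps)
    probs = [e / total for e in exps]
    selectors = [single_expert_selector(z, gamma) for z in z_list]
    num_experts = len(selectors[0])
    return [
        sum(p * sel[i] for p, sel in zip(probs, selectors))
        for i in range(num_experts)
    ]


# Example: 4 experts (2 bits per selector), 2 selectors, both saturated.
gate = dselect_k_gate([[1.0, -1.0], [-1.0, 1.0]], alpha=[0.0, 0.0])
```

In this example each selector has saturated to a single expert, so the gate puts weight 0.5 on each of two experts and exactly zero on the rest. A larger gamma widens the region where the smooth-step is fractional, so the selection stays soft for a wider range of logits.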

forward(self, inputs, task_name=None)