LibMTL.architecture.DSelect_k
- class DSelect_k(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)
Bases: LibMTL.architecture.MMoE.MMoE
DSelect-k.
This method was proposed in DSelect-k: Differentiable Selection in the Mixture of Experts with Applications to Multi-Task Learning (NeurIPS 2021) and is implemented by adapting the official TensorFlow implementation.
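The gating idea from the paper can be sketched in plain Python. This is an illustrative, static (input-independent) gate written from the paper's description, not LibMTL's code. The helper names `smooth_step`, `single_expert_selector` and `dselect_k_gate` are hypothetical, and the sketch assumes the number of experts is a power of two, so that each expert corresponds to one binary code of length m = log2(num_experts). Each of the k selectors passes an m-dimensional vector z through the smooth-step function to obtain a distribution over experts. The k distributions are then mixed with softmax weights computed from a vector alpha. In the per-example variant described in the paper, z and alpha would be produced from the input by learned layers; that step is omitted here.

```python
import math
from typing import List, Sequence


def smooth_step(t: float, gamma: float = 1.0) -> float:
    """Cubic smooth-step: 0 for t <= -gamma/2, 1 for t >= gamma/2, cubic in between."""
    if t <= -gamma / 2:
        return 0.0
    if t >= gamma / 2:
        return 1.0
    return -2.0 / gamma**3 * t**3 + 3.0 / (2.0 * gamma) * t + 0.5


def single_expert_selector(z: Sequence[float], num_experts: int, gamma: float = 1.0) -> List[float]:
    """Turn m continuous codes into a distribution over 2**m experts.

    Expert i gets the product over bits j of S(z_j) if bit j of i is 1,
    otherwise 1 - S(z_j). The weights always sum to 1.
    """
    s = [smooth_step(zj, gamma) for zj in z]
    weights = []
    for i in range(num_experts):
        w = 1.0
        for j, sj in enumerate(s):
            w *= sj if (i >> j) & 1 else 1.0 - sj
        weights.append(w)
    return weights


def dselect_k_gate(zs: Sequence[Sequence[float]], alphas: Sequence[float],
                   num_experts: int, gamma: float = 1.0) -> List[float]:
    """Hypothetical static DSelect-k gate: softmax(alphas)-weighted mix of k selectors."""
    m = num_experts.bit_length() - 1
    if num_experts <= 0 or (1 << m) != num_experts:
        raise ValueError("this sketch assumes num_experts is a power of two")
    if len(zs) != len(alphas) or any(len(z) != m for z in zs):
        raise ValueError("need one alpha per selector and log2(num_experts) codes each")

    # Numerically stable softmax over the k selector weights.
    top = max(alphas)
    exps = [math.exp(a - top) for a in alphas]
    total = sum(exps)
    probs = [e / total for e in exps]

    gate = [0.0] * num_experts
    for p, z in zip(probs, zs):
        for i, r in enumerate(single_expert_selector(z, num_experts, gamma)):
            gate[i] += p * r
    return gate


if __name__ == "__main__":
    # Two saturated selectors (k = 2) pick expert 3 and expert 0 with equal weight.
    print(dselect_k_gate([[1.0, 1.0], [-1.0, -1.0]], [0.0, 0.0], num_experts=4))
    # prints [0.5, 0.0, 0.0, 0.5]
```

A task's output would then be the gate-weighted sum of the expert outputs. Because the smooth-step saturates at exactly 0 and 1, a trained gate can place exactly zero weight on all but at most k experts, which is the sparsity the method is designed for.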
- Parameters
img_size (list) – The size of the input data. For example, [3, 224, 224] denotes input images of size 3x224x224.
num_experts (int) – The number of experts shared by all the tasks. Each expert is an encoder network.
num_nonzeros (int) – The number of experts selected by each task's gate (k in DSelect-k).
kgamma (float, default=1.0) – A scaling parameter for the smooth-step function.
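To illustrate what a smooth-step scaling parameter controls, the sketch below implements the cubic smooth-step function from the DSelect-k paper together with its derivative. Here `gamma` stands in for `kgamma`; that correspondence is an assumption based on the parameter description, not a statement about LibMTL's internals, and `smooth_step_grad` is a hypothetical helper. The function is exactly 0 for t <= -gamma/2 and exactly 1 for t >= gamma/2, so gamma sets the width of the region in which a gate value is fractional and receives a nonzero gradient.

```python
def smooth_step(t: float, gamma: float = 1.0) -> float:
    """Cubic smooth-step: 0 for t <= -gamma/2, 1 for t >= gamma/2, cubic in between."""
    if t <= -gamma / 2:
        return 0.0
    if t >= gamma / 2:
        return 1.0
    return -2.0 / gamma**3 * t**3 + 3.0 / (2.0 * gamma) * t + 0.5


def smooth_step_grad(t: float, gamma: float = 1.0) -> float:
    """Derivative of smooth_step; exactly zero outside (-gamma/2, gamma/2)."""
    if abs(t) >= gamma / 2:
        return 0.0
    return -6.0 / gamma**3 * t**2 + 3.0 / (2.0 * gamma)


if __name__ == "__main__":
    # A larger gamma widens the transition region, so the same input t
    # can still be fractional (and trainable) instead of saturated.
    for gamma in (0.5, 1.0, 2.0):
        values = [round(smooth_step(t, gamma), 4) for t in (-0.5, -0.25, 0.0, 0.25, 0.5)]
        print(f"gamma={gamma}: S(t)={values}, S'(0.75)={smooth_step_grad(0.75, gamma):.4f}")
```

The derivative also vanishes at the boundaries t = +/-gamma/2, so the function joins the constant regions smoothly; a small gamma makes selections saturate to exact 0/1 quickly, while a larger one keeps them trainable over a wider range of inputs.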