LibMTL.architecture.MMoE
- class MMoE(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)
Bases: LibMTL.architecture.abstract_arch.AbsArchitecture
Multi-gate Mixture-of-Experts (MMoE).
This method was proposed in "Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts" (KDD 2018) and is implemented by us.
- Parameters:
img_size (list) – The size of the input data. For example, [3, 224, 224] denotes input images of size 3x224x224.
num_experts (int) – The number of experts shared by all tasks. Each expert is an encoder network.
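To make the roles of img_size and num_experts concrete, the following is a minimal PyTorch sketch of the MMoE idea: several shared experts whose outputs are mixed by one softmax gate per task. It is an illustration under stated assumptions, not LibMTL's implementation. The class name TinyMMoE, the rep_dim argument, the task names, and the use of a flatten-plus-linear layer as the "encoder" are all hypothetical.

```python
import math

import torch
import torch.nn as nn


class TinyMMoE(nn.Module):
    """Illustrative MMoE: shared experts mixed by one softmax gate per task."""

    def __init__(self, img_size, num_experts, task_names, rep_dim=8):
        super().__init__()
        in_dim = math.prod(img_size)  # flattened input size, e.g. 3*8*8
        # Hypothetical experts: flatten + linear stands in for an encoder network.
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Flatten(), nn.Linear(in_dim, rep_dim))
             for _ in range(num_experts)]
        )
        # One gate per task, producing one logit per expert.
        self.gates = nn.ModuleDict(
            {t: nn.Linear(in_dim, num_experts) for t in task_names}
        )

    def forward(self, x):
        # Stack expert outputs: (batch, num_experts, rep_dim)
        expert_out = torch.stack([e(x) for e in self.experts], dim=1)
        flat = torch.flatten(x, start_dim=1)
        reps = {}
        for task, gate in self.gates.items():
            w = torch.softmax(gate(flat), dim=-1)  # (batch, num_experts)
            # Task representation = gate-weighted sum over experts.
            reps[task] = (w.unsqueeze(-1) * expert_out).sum(dim=1)
        return reps


model = TinyMMoE(img_size=[3, 8, 8], num_experts=4, task_names=["seg", "depth"])
reps = model(torch.randn(2, 3, 8, 8))
```

Here each task receives its own mixture of the same four experts, so the experts are shared while the gates are task-specific.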
- forward(inputs, task_name=None)
- Parameters:
inputs (torch.Tensor) – The input data.
task_name (str, default=None) – The task name corresponding to inputs if multi_input is True.
- Returns:
A dictionary of name-prediction pairs of type (str, torch.Tensor).
- Return type:
dict
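The calling convention described above can be sketched as follows. This is an assumption inferred from the parameter descriptions, not code taken from LibMTL: the helper predict, the task names, and the linear decoders are hypothetical, and the library's actual forward may differ in detail.

```python
import torch
import torch.nn as nn


def predict(decoders, reps, task_name=None):
    """Return {task: prediction}; restrict to one task if task_name is given."""
    out = {}
    for task, decoder in decoders.items():
        if task_name is not None and task != task_name:
            continue  # assumed multi-input case: inputs belong to a single task
        out[task] = decoder(reps[task])
    return out


# Hypothetical per-task decoders and per-task representations.
decoders = nn.ModuleDict({"seg": nn.Linear(8, 5), "depth": nn.Linear(8, 1)})
reps = {"seg": torch.randn(2, 8), "depth": torch.randn(2, 8)}

all_preds = predict(decoders, reps)                     # one entry per task
one_pred = predict(decoders, reps, task_name="depth")   # a single entry
```

In both cases the result is a dict keyed by task name, which matches the documented return type.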
Return the shared parameters of the model.
Set gradients of the shared parameters to zero.
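As a rough illustration of these two operations, the sketch below treats the experts as the shared parameters and a single head as task-specific. That split, the class name SharedExperts, and the method names shared_parameters and zero_shared_grads are hypothetical choices for this example, not LibMTL's API.

```python
import torch
import torch.nn as nn


class SharedExperts(nn.Module):
    def __init__(self):
        super().__init__()
        # Assumed shared part: the experts.
        self.experts = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
        # Assumed task-specific part: one output head.
        self.head = nn.Linear(4, 1)

    def shared_parameters(self):  # hypothetical name
        return self.experts.parameters()

    def zero_shared_grads(self):  # hypothetical name
        # Zero only the shared gradients; task-specific ones are untouched.
        for p in self.shared_parameters():
            if p.grad is not None:
                p.grad.zero_()


model = SharedExperts()
x = torch.randn(3, 4)
loss = model.head(sum(e(x) for e in model.experts)).sum()
loss.backward()
model.zero_shared_grads()
```

After the call, the expert gradients are all zero while the head keeps the gradients from the backward pass.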