LibMTL.architecture.abstract_arch
- class AbsArchitecture(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)
Bases: torch.nn.Module
An abstract class for multi-task learning (MTL) architectures.
- Parameters:
task_name (list) – A list of strings for all tasks.
encoder_class (class) – The neural network class of the encoder (a class, not an instance).
decoders (dict) – A dictionary of name-decoder pairs of type (str, torch.nn.Module).
rep_grad (bool) – If True, the gradient of the representation can be computed separately for each task.
multi_input (bool) – True if each task has its own input data, otherwise False.
device (torch.device) – The device on which the model and data are allocated.
kwargs (dict) – A dictionary of architecture-specific hyperparameters.
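The `rep_grad` option is easiest to picture with a small example. The sketch below shows one common way to obtain a separate gradient of a shared representation for each task: every task receives its own detached copy of the representation marked with `requires_grad`, so each task loss stops at that copy. It is a standalone illustration, not LibMTL's implementation; the encoder, the decoders, the squared-output losses and the plain summation used to push the gradients back into the encoder are all assumptions chosen for brevity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical shared encoder and two task-specific decoders.
encoder = nn.Linear(4, 3)
decoders = {"task_a": nn.Linear(3, 1), "task_b": nn.Linear(3, 1)}

x = torch.randn(5, 4)
shared_rep = encoder(x)  # one representation shared by all tasks

# Give each task its own leaf copy of the representation, so that
# backpropagating a task loss records a gradient on that copy only.
rep_tasks = {}
losses = {}
for task, decoder in decoders.items():
    rep_tasks[task] = shared_rep.detach().clone().requires_grad_(True)
    losses[task] = decoder(rep_tasks[task]).pow(2).mean()  # toy loss

for loss in losses.values():
    loss.backward()

# Per-task gradients with respect to the shared representation.
rep_grads = {task: rep.grad for task, rep in rep_tasks.items()}

# The copies cut the graph, so the encoder has received no gradient yet.
encoder_grad_before = encoder.weight.grad

# Hypothetical combination step: sum the per-task gradients and
# backpropagate the result through the encoder.
shared_rep.backward(sum(rep_grads.values()))
```

A weighting method would replace the plain sum with its own rule for combining `rep_grads`; the detach-and-copy step is what makes those per-task gradients available in the first place.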
- forward(inputs, task_name=None)
- Parameters:
inputs (torch.Tensor) – The input data.
task_name (str, default=None) – The task name corresponding to inputs if multi_input is True.
- Returns:
A dictionary of name-prediction pairs of type (str, torch.Tensor).
- Return type:
dict
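To make the documented `forward` contract concrete, here is a minimal hard-parameter-sharing sketch written directly on `torch.nn.Module`. It mirrors the documented behaviour (a dictionary of name-prediction pairs, and only the named task's prediction when `task_name` is given, as in the `multi_input` case), but it is an assumption-laden illustration rather than LibMTL code: the class names `HardSharingSketch` and `Encoder`, the task names and all layer sizes are hypothetical.

```python
import torch
import torch.nn as nn


class HardSharingSketch(nn.Module):
    """Hypothetical model mirroring the documented constructor and forward()."""

    def __init__(self, task_name, encoder_class, decoders):
        super().__init__()
        self.task_name = task_name
        self.encoder = encoder_class()           # encoder_class is a class, instantiated here
        self.decoders = nn.ModuleDict(decoders)  # name -> decoder module

    def forward(self, inputs, task_name=None):
        rep = self.encoder(inputs)
        out = {}
        for task in self.task_name:
            # When a task name is given (multi-input setting), run only that decoder.
            if task_name is not None and task != task_name:
                continue
            out[task] = self.decoders[task](rep)
        return out


class Encoder(nn.Module):
    """Hypothetical encoder; any nn.Module subclass with a no-argument constructor works here."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU())

    def forward(self, x):
        return self.net(x)


model = HardSharingSketch(
    task_name=["seg", "depth"],
    encoder_class=Encoder,
    decoders={"seg": nn.Linear(16, 3), "depth": nn.Linear(16, 1)},
)
x = torch.randn(2, 8)
all_preds = model(x)                  # keys "seg" (shape 2x3) and "depth" (shape 2x1)
seg_only = model(x, task_name="seg")  # only the "seg" key
```

Returning a dictionary keyed by task name lets the training loop look up each task's loss function by the same key, whether one or all tasks were run.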
Return the shared parameters of the model.
Set gradients of the shared parameters to zero.
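The two operations above can be sketched as follows. The method names `shared_parameters` and `zero_grad_shared` are hypothetical stand-ins, and the sketch assumes that the shared parameters are exactly the encoder's parameters, which holds for hard parameter sharing but may differ in other architectures. Zeroing the shared gradients in place, while leaving the task-specific heads untouched, is useful when per-task gradients of the shared parameters are collected one task at a time.

```python
import torch
import torch.nn as nn


class SharedParamsSketch(nn.Module):
    """Hypothetical model: a shared encoder plus task-specific heads."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)  # shared across tasks
        self.decoders = nn.ModuleDict({"a": nn.Linear(4, 1), "b": nn.Linear(4, 1)})

    def shared_parameters(self):
        # Hypothetical name: assumes the encoder holds all shared parameters.
        return self.encoder.parameters()

    def zero_grad_shared(self):
        # Hypothetical name: zero (rather than discard) the shared gradients.
        for p in self.shared_parameters():
            if p.grad is not None:
                p.grad.zero_()


model = SharedParamsSketch()
x = torch.randn(3, 4)
rep = model.encoder(x)
loss = sum(dec(rep).sum() for dec in model.decoders.values())
loss.backward()

model.zero_grad_shared()
shared_grads_zero = all(bool((p.grad == 0).all()) for p in model.shared_parameters())
heads_have_grads = all(p.grad is not None for p in model.decoders.parameters())
```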