LibMTL.architecture.abstract_arch
- class AbsArchitecture(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)
Bases: torch.nn.Module
An abstract base class for multi-task learning (MTL) architectures.
- Parameters
task_name (list) – A list of strings naming all tasks.
encoder_class (class) – A neural network class used as the encoder.
decoders (dict) – A dictionary of name-decoder pairs of type (str, torch.nn.Module).
rep_grad (bool) – If True, the gradient of the representation for each task can be computed.
multi_input (bool) – True if each task has its own input data, otherwise False.
device (torch.device) – The device on which the model and data will be allocated.
kwargs (dict) – A dictionary of architecture-specific hyperparameters.
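The constructor arguments above describe a shared encoder plus one decoder per task. The following is a minimal pure-Python sketch of how a subclass might store them. It is not LibMTL's implementation. The class names ToyArchitecture and DoubleEncoder are hypothetical. Plain callables stand in for torch.nn.Module so the sketch runs without PyTorch, and device is a plain string here. The sketch also assumes that encoder_class is instantiated once with no arguments and that every task name has a decoder registered under the same key.

```python
from typing import Callable, Dict, List


class ToyArchitecture:
    """Hypothetical stand-in for an AbsArchitecture subclass (not LibMTL code)."""

    def __init__(
        self,
        task_name: List[str],
        encoder_class: Callable[[], Callable],
        decoders: Dict[str, Callable],
        rep_grad: bool = False,
        multi_input: bool = False,
        device: str = "cpu",
        **kwargs,
    ) -> None:
        # Assumption: each task needs a decoder stored under its own name.
        missing = [t for t in task_name if t not in decoders]
        if missing:
            raise ValueError(f"no decoder for tasks: {missing}")
        self.task_name = task_name
        self.task_num = len(task_name)
        self.encoder = encoder_class()  # one encoder instance shared by all tasks
        self.decoders = decoders        # task-specific heads, keyed by task name
        self.rep_grad = rep_grad
        self.multi_input = multi_input
        self.device = device            # a string here; torch.device in the real API
        self.kwargs = kwargs            # extra architecture hyperparameters


class DoubleEncoder:
    """Hypothetical encoder: doubles every input value."""

    def __call__(self, x):
        return [2 * v for v in x]


arch = ToyArchitecture(
    task_name=["seg", "depth"],
    encoder_class=DoubleEncoder,
    decoders={"seg": sum, "depth": max},
    some_hyperparam=0.5,
)
print(arch.task_num, arch.kwargs)  # → 2 {'some_hyperparam': 0.5}
```

Any keyword argument not matched by a named parameter lands in kwargs, which is how architecture-specific settings can travel through a common constructor signature.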
- forward(self, inputs, task_name=None)
- Parameters
inputs (torch.Tensor) – The input data.
task_name (str, default=None) – The task name corresponding to inputs if multi_input is True.
- Returns
A dictionary of name-prediction pairs of type (str, torch.Tensor).
- Return type
dict
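The forward contract above (one shared representation, then a dictionary of per-task predictions) can be sketched in pure Python as follows. The function toy_forward is hypothetical and only mirrors the documented inputs and outputs, not LibMTL's internals. Requiring task_name when multi_input is True is an assumption drawn from the parameter description.

```python
from typing import Callable, Dict, List, Optional


def toy_forward(
    encoder: Callable,
    decoders: Dict[str, Callable],
    task_names: List[str],
    inputs,
    multi_input: bool,
    task_name: Optional[str] = None,
) -> Dict[str, object]:
    """Hypothetical sketch of forward(): returns name-prediction pairs."""
    rep = encoder(inputs)  # shared representation
    if multi_input:
        # Assumption: with per-task inputs, the caller names the task they belong to,
        # so only that task's decoder runs.
        if task_name is None:
            raise ValueError("task_name is required when multi_input is True")
        return {task_name: decoders[task_name](rep)}
    # Single shared input: every task's decoder sees the same representation.
    return {t: decoders[t](rep) for t in task_names}


def encoder(x):
    return [2 * v for v in x]


decoders = {"seg": sum, "depth": max}
print(toy_forward(encoder, decoders, ["seg", "depth"], [1, 2, 3], multi_input=False))
# → {'seg': 12, 'depth': 6}
print(toy_forward(encoder, decoders, ["seg", "depth"], [1, 2, 3], multi_input=True, task_name="depth"))
# → {'depth': 6}
```

In the single-input case one encoder pass serves every task, which is the main efficiency argument for sharing the encoder.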
Return the shared parameters of the model.
Set gradients of the shared parameters to zero.
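The two helpers above operate only on the shared parameters, leaving task-specific decoder parameters alone. One plausible use is gradient-combination schemes that compute per-task gradients with respect to the shared parameters and need to clear them between backward passes. The sketch below is a hypothetical pure-Python illustration: ToyParam and ToySharedModel are invented names, and in PyTorch the equivalent loop would typically call p.grad.zero_() on each shared parameter whose gradient is not None.

```python
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional


@dataclass
class ToyParam:
    """Hypothetical stand-in for a parameter tensor with an accumulated gradient."""
    value: float
    grad: Optional[float] = None


class ToySharedModel:
    """Hypothetical model: shared encoder parameters plus per-task decoder parameters."""

    def __init__(self, task_names: List[str]) -> None:
        self.encoder_params = [ToyParam(1.0), ToyParam(-0.5)]  # shared by all tasks
        self.decoder_params: Dict[str, List[ToyParam]] = {
            t: [ToyParam(0.1)] for t in task_names  # one head per task
        }

    def shared_params(self) -> Iterator[ToyParam]:
        # Only the encoder's parameters count as shared; decoder heads are excluded.
        return iter(self.encoder_params)

    def zero_shared_grads(self) -> None:
        # Reset gradients on shared parameters only; decoder gradients are untouched.
        for p in self.shared_params():
            if p.grad is not None:
                p.grad = 0.0


model = ToySharedModel(["seg", "depth"])
for p in model.shared_params():
    p.grad = 3.0
model.decoder_params["seg"][0].grad = 7.0
model.zero_shared_grads()
print([p.grad for p in model.shared_params()])  # → [0.0, 0.0]
print(model.decoder_params["seg"][0].grad)      # → 7.0
```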