LibMTL.architecture.abstract_arch

class AbsArchitecture(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)

Bases: torch.nn.Module

An abstract base class for multi-task learning (MTL) architectures.

Parameters
  • task_name (list) – A list of strings for all tasks.

  • encoder_class (class) – A neural network class used to build the shared encoder.

  • decoders (dict) – A dictionary of name-decoder pairs of type (str, torch.nn.Module).

  • rep_grad (bool) – If True, the gradient of each task's loss with respect to the shared representation can be computed.

  • multi_input (bool) – True if each task has its own input data; False if all tasks share the same input.

  • device (torch.device) – The device where model and data will be allocated.

  • kwargs (dict) – A dictionary of hyperparameters of architectures.
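The effect of rep_grad can be pictured with the sketch below. It assumes one common way of obtaining a separate gradient per task with respect to a shared representation: detach the representation, re-enable gradients on the copy, and backpropagate each task's loss only as far as that copy. The task names, layer sizes, and squared-output losses are hypothetical, and this is not presented as LibMTL's own implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical setup: a shared linear encoder and one linear head per task.
encoder = nn.Linear(4, 3)
decoders = nn.ModuleDict({"seg": nn.Linear(3, 1), "depth": nn.Linear(3, 1)})

x = torch.randn(8, 4)
rep = encoder(x)  # shared representation, still attached to the encoder graph

rep_grads = {}
for task, head in decoders.items():
    # Detached leaf copy: backward() on this task's loss stops here and
    # stores the task's gradient with respect to the representation.
    rep_tilde = rep.detach().clone().requires_grad_(True)
    loss = head(rep_tilde).pow(2).mean()  # placeholder loss for illustration
    loss.backward()
    rep_grads[task] = rep_tilde.grad

for task, g in rep_grads.items():
    print(task, tuple(g.shape))  # seg (8, 3) then depth (8, 3)

# The per-task gradients can then be combined (here: a plain sum, an
# assumption) and pushed through the original graph into the encoder.
rep.backward(sum(rep_grads.values()))
```

Summing is only a placeholder; the point of exposing per-task representation gradients is that a weighting method can combine them differently before they reach the shared encoder.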

forward(self, inputs, task_name=None)
Parameters
  • inputs (torch.Tensor) – The input data.

  • task_name (str, default=None) – The name of the task that inputs belongs to; used when multi_input is True.

Returns

A dictionary of name-prediction pairs of type (str, torch.Tensor).

Return type

dict
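A minimal sketch of a concrete model that follows the constructor and forward() contract described above. SketchArch and TinyEncoder are hypothetical names; the class subclasses torch.nn.Module directly rather than AbsArchitecture so that it runs without LibMTL installed, and it assumes a single representation shared by all task heads.

```python
import torch
import torch.nn as nn


class SketchArch(nn.Module):
    """Hypothetical hard-parameter-sharing model mirroring the documented
    AbsArchitecture interface; not LibMTL's own code."""

    def __init__(self, task_name, encoder_class, decoders, rep_grad,
                 multi_input, device, **kwargs):
        super().__init__()
        self.task_name = task_name
        self.encoder = encoder_class()           # shared encoder instance
        self.decoders = nn.ModuleDict(decoders)  # registers one head per task
        self.rep_grad = rep_grad
        self.multi_input = multi_input
        self.device = device
        self.kwargs = kwargs

    def forward(self, inputs, task_name=None):
        rep = self.encoder(inputs)
        out = {}
        for task in self.task_name:
            # When a task name is given (the multi_input case), the batch
            # belongs to that task only, so other heads are skipped.
            if task_name is not None and task != task_name:
                continue
            out[task] = self.decoders[task](rep)
        return out


class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 3)

    def forward(self, x):
        return self.net(x)


model = SketchArch(["seg", "depth"], TinyEncoder,
                   {"seg": nn.Linear(3, 2), "depth": nn.Linear(3, 1)},
                   rep_grad=False, multi_input=False,
                   device=torch.device("cpu"))

preds = model(torch.randn(5, 4))
print({k: tuple(v.shape) for k, v in preds.items()})  # {'seg': (5, 2), 'depth': (5, 1)}

one = model(torch.randn(5, 4), task_name="depth")
print(list(one))  # ['depth']
```

Returning a dictionary keyed by task name lets the caller match each prediction to its task-specific loss without relying on output order.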

get_share_params(self)

Return the shared parameters of the model.
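One way to picture get_share_params() is as returning only the parameters of the shared encoder. The two-head model below is hypothetical and assumes the encoder is the only shared component, so the count excludes the task-specific heads.

```python
import torch.nn as nn


class TwoHeadModel(nn.Module):
    """Hypothetical model: one shared encoder plus two task-specific heads."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 3)  # shared: 4*3 weights + 3 biases = 15
        self.decoders = nn.ModuleDict({
            "a": nn.Linear(3, 1),  # 3 + 1 = 4 parameters
            "b": nn.Linear(3, 2),  # 6 + 2 = 8 parameters
        })

    def get_share_params(self):
        # Assumption: only the encoder is shared across tasks.
        return self.encoder.parameters()


model = TwoHeadModel()
n_shared = sum(p.numel() for p in model.get_share_params())
n_total = sum(p.numel() for p in model.parameters())
print(n_shared, n_total)  # 15 27
```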

zero_grad_share_params(self)

Set gradients of the shared parameters to zero.
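Gradient-based weighting strategies typically need each task's gradient with respect to the shared parameters separately. The sketch below shows one way zero_grad_share_params() and get_share_params() could be combined for that purpose; TwoHeadModel, its layer sizes, and the squared-output losses are hypothetical, and the loop is an illustration rather than LibMTL's own code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TwoHeadModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 3)  # shared: 15 parameters
        self.decoders = nn.ModuleDict({"a": nn.Linear(3, 1), "b": nn.Linear(3, 1)})

    def get_share_params(self):
        return self.encoder.parameters()

    def zero_grad_share_params(self):
        # Reset shared gradients to zero tensors (not None).
        self.encoder.zero_grad(set_to_none=False)


model = TwoHeadModel()
x = torch.randn(6, 4)
rep = model.encoder(x)
losses = {t: head(rep).pow(2).mean() for t, head in model.decoders.items()}

task_grads = {}
for t, loss in losses.items():
    model.zero_grad_share_params()     # each task starts from zero shared grads
    loss.backward(retain_graph=True)   # keep the shared graph for the next task
    # torch.cat copies, so later zeroing does not overwrite the stored vector.
    task_grads[t] = torch.cat([p.grad.flatten() for p in model.get_share_params()])

print({t: tuple(g.shape) for t, g in task_grads.items()})  # {'a': (15,), 'b': (15,)}
```

Without the reset between tasks, backward() would accumulate gradients, and the second vector would hold the sum of both tasks' gradients instead of task b's alone.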