LibMTL.architecture.MTAN

class MTAN(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)[source]

Bases: LibMTL.architecture.abstract_arch.AbsArchitecture

Multi-Task Attention Network (MTAN).

This method is proposed in End-To-End Multi-Task Learning With Attention (CVPR 2019) and implemented by modifying the official PyTorch implementation.

Warning

MTAN is supported only with ResNet-based encoders.
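
A minimal construction sketch is given below. It assumes that task_name is a list of task names, that encoder_class is a zero-argument callable returning an nn.Module, and that decoders is an nn.ModuleDict keyed by task name; the torchvision ResNet and the 1x1 convolution decoders are illustrative stand-ins rather than the library's reference setup.

    import torch
    import torch.nn as nn
    from torchvision.models import resnet18
    from LibMTL.architecture import MTAN

    # Illustrative ResNet-based encoder (see the warning above); whether a plain
    # torchvision ResNet meets MTAN's encoder requirements is an assumption.
    def encoder_class():
        return resnet18()

    task_names = ['segmentation', 'depth']   # hypothetical tasks
    decoders = nn.ModuleDict({
        'segmentation': nn.Conv2d(512, 13, kernel_size=1),  # toy 1x1-conv heads
        'depth': nn.Conv2d(512, 1, kernel_size=1),
    })

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MTAN(task_name=task_names,
                 encoder_class=encoder_class,
                 decoders=decoders,
                 rep_grad=False,
                 multi_input=False,
                 device=device).to(device)

In typical LibMTL workflows the architecture is usually instantiated indirectly through the Trainer, which forwards these same arguments.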

forward(self, inputs, task_name=None)[source]
Parameters
  • inputs (torch.Tensor) – The input data.

  • task_name (str, default=None) – The task name corresponding to inputs if multi_input is True.

Returns

A dictionary mapping each task name (str) to its prediction (torch.Tensor).

Return type

dict
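
A short usage sketch for forward, continuing the construction example above; the batch shape and task names are illustrative assumptions.

    x = torch.randn(4, 3, 224, 224, device=device)

    # multi_input=False: one batch is shared by all tasks, so task_name is
    # omitted and the returned dict holds one prediction per task.
    preds = model(x)
    seg_pred, depth_pred = preds['segmentation'], preds['depth']

    # multi_input=True (assumed setting): each task has its own batch, so the
    # matching task_name is passed; the returned dict then presumably contains
    # only that task's entry.
    seg_only = model(x, task_name='segmentation')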

get_share_params(self)[source]

Return the shared parameters of the model.
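
As a hedged illustration, the shared parameters can be placed in their own optimizer group, separate from the task-specific decoders; the learning rates and the assumption that the decoders are reachable as model.decoders are not taken from this page.

    shared_params = model.get_share_params()
    optimizer = torch.optim.Adam([
        {'params': shared_params, 'lr': 1e-4},                 # shared encoder
        {'params': model.decoders.parameters(), 'lr': 1e-3},   # task heads (assumed attribute)
    ])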

zero_grad_share_params(self)[source]

Set gradients of the shared parameters to zero.
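
A hedged sketch of where this is typically useful: clearing the shared-parameter gradients between per-task backward passes so that each task's gradient on the shared part can be read out separately. The losses dict and the retain_graph bookkeeping are illustrative, not part of the API described here.

    per_task_grads = {}
    for task, loss in losses.items():       # losses: dict of scalar task losses (assumed)
        model.zero_grad_share_params()      # clear shared grads before this task's backward
        loss.backward(retain_graph=True)    # keep the graph for the remaining tasks
        per_task_grads[task] = [None if p.grad is None else p.grad.detach().clone()
                                for p in model.get_share_params()]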