LibMTL.architecture.MTAN
- class MTAN(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)
Bases: LibMTL.architecture.abstract_arch.AbsArchitecture
Multi-Task Attention Network (MTAN).
This method was proposed in End-To-End Multi-Task Learning With Attention (CVPR 2019); this implementation is adapted from the official PyTorch implementation.
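As described in the paper, the core idea of MTAN is that each task owns a small attention module that learns a soft mask over the shared encoder's features, so every task reads its own feature selection from one shared backbone. The sketch below illustrates only that gating step, under heavy simplifying assumptions: features are flat Python lists instead of feature maps, the convolutional layers that produce the mask are replaced by two hypothetical scalar weights (`w_inner`, `w_prev`) and a bias, and the first block uses zeros in place of an earlier task feature. None of these names come from LibMTL.

```python
import math
from typing import List, Optional


def sigmoid(x: float) -> float:
    """Numerically stable logistic function; the result lies in [0, 1]."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def task_attention(
    shared_inner: List[float],
    shared_out: List[float],
    prev_attended: Optional[List[float]],
    w_inner: float,
    w_prev: float,
    bias: float,
) -> List[float]:
    """Toy, hypothetical stand-in for one MTAN task-specific attention step.

    shared_inner:  shared feature from inside the current encoder block
    shared_out:    shared feature at the output of the same block
    prev_attended: this task's attended feature from the previous block (None for the first)
    w_inner, w_prev, bias: hypothetical scalars replacing the real convolutional layers
    """
    if prev_attended is None:
        # First block: there is no earlier task-specific feature, so it contributes nothing.
        prev_attended = [0.0] * len(shared_inner)
    # Soft mask computed from shared and task-specific information.
    mask = [
        sigmoid(w_inner * u + w_prev * a + bias)
        for u, a in zip(shared_inner, prev_attended)
    ]
    # Element-wise gating of the shared output: each task gets its own view.
    return [m * p for m, p in zip(mask, shared_out)]


# Usage: two tasks with different (hypothetical) weights read the same shared features.
shared_inner = [0.5, -1.0, 2.0]
shared_out = [1.0, 2.0, -3.0]
seg_block1 = task_attention(shared_inner, shared_out, None, w_inner=1.0, w_prev=0.5, bias=0.0)
depth_block1 = task_attention(shared_inner, shared_out, None, w_inner=-1.0, w_prev=0.5, bias=0.0)
```

Because each mask value lies between 0 and 1, a task can only attenuate a shared feature or (approximately) pass it through; two tasks with different attention weights therefore receive different views of the same shared output.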
Warning
MTAN is only supported by ResNet-based encoders.
- forward(self, inputs, task_name=None)
- Parameters
inputs (torch.Tensor) – The input data.
task_name (str, default=None) – The task name corresponding to inputs if multi_input is True.
- Returns
A dictionary of name-prediction pairs of type (str, torch.Tensor).
- Return type
dict
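The `task_name` / `multi_input` contract and the dictionary return type can be illustrated with a toy, pure-Python model. `ToyMultiTaskModel`, its placeholder encoder, and the example decoders are all hypothetical. The sketch mirrors the documented behaviour that predictions come back as a dict keyed by task name, and assumes (this page does not confirm it) that with `multi_input=True` only the named task is decoded. Raising `ValueError` when the task name is missing is a choice made by this sketch, not documented LibMTL behaviour.

```python
from typing import Callable, Dict, List, Optional

Features = List[float]
Decoder = Callable[[Features], Features]


class ToyMultiTaskModel:
    """Hypothetical model mimicking the documented forward() return contract."""

    def __init__(self, task_names: List[str], decoders: Dict[str, Decoder], multi_input: bool) -> None:
        self.task_names = task_names
        self.decoders = decoders
        self.multi_input = multi_input

    def encode(self, inputs: Features) -> Features:
        # Placeholder shared encoder: doubles every input value.
        return [2.0 * x for x in inputs]

    def forward(self, inputs: Features, task_name: Optional[str] = None) -> Dict[str, Features]:
        if self.multi_input:
            # Assumption of this sketch: each input batch belongs to exactly one task.
            if task_name not in self.decoders:
                raise ValueError("multi_input=True requires a known task_name")
            tasks = [task_name]
        else:
            # Single shared input: decode it for every task.
            tasks = self.task_names
        shared = self.encode(inputs)
        # Name-prediction pairs, matching the documented return type (a dict).
        return {t: self.decoders[t](shared) for t in tasks}


decoders: Dict[str, Decoder] = {
    "segmentation": lambda f: [x + 1.0 for x in f],
    "depth": lambda f: [sum(f)],
}
single = ToyMultiTaskModel(["segmentation", "depth"], decoders, multi_input=False)
multi = ToyMultiTaskModel(["segmentation", "depth"], decoders, multi_input=True)
print(single.forward([1.0, 2.0]))  # {'segmentation': [3.0, 5.0], 'depth': [6.0]}
print(multi.forward([1.0], task_name="depth"))  # {'depth': [2.0]}
```

Returning a dict rather than a tuple lets downstream code (for example, per-task loss computation) look predictions up by task name, regardless of how many tasks were decoded for a given batch.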
Return the shared parameters of the model.
Set gradients of the shared parameters to zero.
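The two descriptions above refer to the encoder parameters that all tasks share. One common reason to expose them separately (stated here as general multi-task practice, not as a description of LibMTL internals) is that gradient-balancing methods compare each task's gradient with respect to the shared parameters, which requires resetting those gradients between per-task backward passes. The sketch below fakes this with scalar `Param` objects and made-up gradient values; `ToyModel`, `shared_params`, `zero_shared_grads`, and `accumulate` are hypothetical names, not LibMTL's API. In PyTorch, the same role is played by the `.grad` attribute of each parameter tensor.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Param:
    """Minimal stand-in for a trainable tensor: a scalar value plus its gradient."""
    value: float
    grad: Optional[float] = None


class ToyModel:
    """Hypothetical model split into shared (encoder) and task-specific (head) parameters."""

    def __init__(self) -> None:
        self.shared: Dict[str, Param] = {"encoder.w": Param(0.5), "encoder.b": Param(0.1)}
        self.heads: Dict[str, Param] = {"seg.w": Param(1.0), "depth.w": Param(-1.0)}

    def shared_params(self) -> List[Param]:
        # Only the encoder parameters that every task updates.
        return list(self.shared.values())

    def zero_shared_grads(self) -> None:
        # Reset shared gradients in place; task-specific heads are left untouched.
        for p in self.shared_params():
            p.grad = 0.0


def accumulate(params: List[Param], grads: List[float]) -> None:
    """Add gradients in place, mimicking how autograd accumulates into .grad."""
    for p, g in zip(params, grads):
        p.grad = (p.grad or 0.0) + g


model = ToyModel()
per_task: Dict[str, List[Optional[float]]] = {}
# Made-up per-task gradients, for illustration only.
for task, fake_grads in {"seg": [0.2, -0.1], "depth": [0.4, 0.3]}.items():
    model.zero_shared_grads()  # without this, "depth" would also contain "seg"'s gradient
    accumulate(model.shared_params(), fake_grads)
    per_task[task] = [p.grad for p in model.shared_params()]
```

After the loop, `per_task` holds each task's gradient on the shared parameters in isolation, which is the quantity a gradient-balancing method would then combine.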