LibMTL.architecture.MTAN
- class MTAN(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)
Bases: LibMTL.architecture.abstract_arch.AbsArchitecture
Multi-Task Attention Network (MTAN).
This method was proposed in End-To-End Multi-Task Learning With Attention (CVPR 2019); this implementation is adapted from the official PyTorch implementation.
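To illustrate the attention idea, here is a simplified PyTorch sketch of a single task-specific attention module in the spirit of the paper; it is not LibMTL's code. Each task learns a soft mask in [0, 1] from the shared features and multiplies it element-wise with the shared block output, so all tasks reuse one backbone while each selects different features. The class name TaskAttention, the channel sizes, and the two-layer 1x1 convolution mask are illustrative assumptions. The paper additionally feeds the previous task-specific features into the mask and resamples between backbone stages, which this sketch omits.

```python
import torch
import torch.nn as nn


class TaskAttention(nn.Module):
    """Simplified MTAN-style attention module (illustrative, not LibMTL's code)."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Two 1x1 conv layers produce a per-channel, per-pixel mask in (0, 1).
        self.mask = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.Sigmoid(),
        )

    def forward(self, mask_input: torch.Tensor, shared_output: torch.Tensor):
        a = self.mask(mask_input)
        # Element-wise gating: the task keeps only the shared features it attends to.
        return a * shared_output, a


torch.manual_seed(0)
# Shared backbone output for a batch of 2 images (16 channels, 8x8 spatial).
shared = torch.randn(2, 16, 8, 8)
attention = {t: TaskAttention(16, 16) for t in ("segmentation", "depth")}
# Each task derives its own mask from the same shared features.
outputs = {t: module(shared, shared) for t, module in attention.items()}
```

Because every mask value lies between 0 and 1, each task's gated features are a down-weighted view of the shared features rather than a separate copy of the backbone.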
Warning
MTAN is only supported by ResNet-based encoders.
- forward(inputs, task_name=None)
- Parameters:
inputs (torch.Tensor) – The input data.
task_name (str, default=None) – The task name corresponding to inputs if multi_input is True.
- Returns:
A dictionary of name-prediction pairs of type (str, torch.Tensor).
- Return type:
dict
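The forward() contract can be illustrated with a small hypothetical model (ToyMultiTaskModel, not part of LibMTL) that uses a shared linear encoder and one linear decoder per task. Called without task_name, it returns a prediction for every task; called with task_name, as in the multi_input=True setting, it returns only that task's entry. The task-skipping logic follows the documented behavior, while the layers themselves are placeholders.

```python
import torch
import torch.nn as nn


class ToyMultiTaskModel(nn.Module):
    """Hypothetical toy model mirroring the documented forward() contract."""

    def __init__(self, task_name, in_dim: int = 4, hidden: int = 8):
        super().__init__()
        self.task_name = list(task_name)
        self.encoder = nn.Linear(in_dim, hidden)  # shared parameters
        self.decoders = nn.ModuleDict({t: nn.Linear(hidden, 1) for t in self.task_name})

    def forward(self, inputs: torch.Tensor, task_name=None) -> dict:
        rep = torch.relu(self.encoder(inputs))
        out = {}
        for task in self.task_name:
            # With one input per task (multi-input case), only the named
            # task's decoder is evaluated; the others are skipped.
            if task_name is not None and task != task_name:
                continue
            out[task] = self.decoders[task](rep)
        return out


torch.manual_seed(0)
model = ToyMultiTaskModel(["segmentation", "depth"])
x = torch.randn(3, 4)
all_preds = model(x)                    # one entry per task
one_pred = model(x, task_name="depth")  # only the "depth" entry
```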
Return the shared parameters of the model.
Set gradients of the shared parameters to zero.
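Access to the shared parameters is typically needed by gradient-based weighting strategies, which compare each task's gradient with respect to the shared encoder. The sketch below shows that pattern on a hypothetical two-task model; shared_parameters and zero_shared_grads are illustrative helpers standing in for the two methods above, not LibMTL's API. Zeroing the shared gradients before each task's backward pass keeps the per-task gradients from accumulating into one another.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical two-task model: a shared encoder and one head per task.
encoder = nn.Linear(4, 8)
heads = nn.ModuleDict({"a": nn.Linear(8, 1), "b": nn.Linear(8, 1)})


def shared_parameters():
    """Return the shared (encoder) parameters only."""
    return list(encoder.parameters())


def zero_shared_grads():
    """Zero the gradients of the shared parameters in place."""
    for p in shared_parameters():
        if p.grad is not None:
            p.grad.zero_()


x = torch.randn(5, 4)
per_task_grads = {}
for task, head in heads.items():
    zero_shared_grads()  # start each task from clean shared gradients
    loss = head(torch.relu(encoder(x))).pow(2).mean()
    loss.backward()
    # torch.cat copies, so later zeroing does not alter the stored vector.
    per_task_grads[task] = torch.cat([p.grad.flatten() for p in shared_parameters()])
```

Each stored vector has one entry per shared parameter (4 * 8 weights plus 8 biases here), which is the form a weighting method would combine or compare.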