LibMTL.architecture.LTB

class LTB(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)

Bases: LibMTL.architecture.abstract_arch.AbsArchitecture

Learning to Branch (LTB).

This method was proposed in Learning to Branch for Multi-Task Learning (ICML 2020); the implementation here is our own.
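The paper's central idea is to learn a tree-structured branching of the shared network. Roughly, each node in a layer chooses a parent node in the previous layer, and that discrete choice is relaxed with Gumbel-softmax sampling so it can be trained by gradient descent. The stdlib-only sketch below illustrates just that parent-selection step under stated assumptions. The `parent_logits` table, the function names, and the argmax readout are hypothetical and are not LibMTL's API. A real implementation would work on learnable tensors; PyTorch, for instance, provides `torch.nn.functional.gumbel_softmax`.

```python
import math
import random


def gumbel_softmax(logits, temperature=1.0, rng=random):
    """Return a soft sample from the categorical distribution defined by `logits`.

    Each logit is perturbed with Gumbel(0, 1) noise, divided by `temperature`,
    and normalized with a numerically stable softmax.
    """
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise is -log(-log(U)) with U ~ Uniform(0, 1); clamp U away from 0.
        u = max(rng.random(), 1e-12)
        noisy.append((logit - math.log(-math.log(u))) / temperature)
    m = max(noisy)  # subtract the max before exponentiating to avoid overflow
    exps = [math.exp(x - m) for x in noisy]
    total = sum(exps)
    return [e / total for e in exps]


def choose_parents(parent_logits, temperature=1.0, rng=random):
    """For each child node, pick the parent with the largest Gumbel-softmax weight.

    parent_logits[c][p] is a hypothetical learnable score for child node c
    branching from parent node p in the previous layer.
    """
    choices = []
    for logits in parent_logits:
        probs = gumbel_softmax(logits, temperature, rng)
        choices.append(max(range(len(probs)), key=probs.__getitem__))
    return choices
```

Lowering `temperature` pushes the soft sample toward a one-hot vector, which is why Gumbel-softmax schemes typically anneal it during training. The argmax in `choose_parents` corresponds to reading off a hard branching decision once the scores have settled.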

Warning

  • LTB does not work with multi-input problems, i.e., multi_input must be False.

  • LTB supports only ResNet-based encoders.
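The two constraints above can be checked before building the model. The helper below is a hypothetical sketch, not a LibMTL function. The caller is assumed to know whether the chosen encoder is ResNet-based and passes that as a boolean, since the encoder class itself gives no reliable flag for it.

```python
def check_ltb_config(multi_input: bool, encoder_is_resnet: bool) -> None:
    """Raise ValueError if a setup violates the LTB constraints (hypothetical helper)."""
    if multi_input:
        raise ValueError("LTB does not support multi-input problems; set multi_input=False.")
    if not encoder_is_resnet:
        raise ValueError("LTB supports only ResNet-based encoders.")


# A valid single-input setup with a ResNet-based encoder passes silently.
check_ltb_config(multi_input=False, encoder_is_resnet=True)
```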

forward(self, inputs, task_name=None)
Parameters
  • inputs (torch.Tensor) – The input data.

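  • task_name (str, default=None) – The task name corresponding to inputs when multi_input is True. Since LTB requires multi_input to be False (see Warning above), this argument does not apply to LTB.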

Returns

A dictionary mapping each task name (str) to its prediction (torch.Tensor).

Return type

dict
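To make the documented return type concrete, the sketch below mimics the shape of forward's output, with plain Python lists standing in for torch.Tensor. Everything here is hypothetical: `toy_forward`, the per-task paths, and the task names "seg" and "depth" are illustrative, not part of LibMTL. The learned branching is reduced to fixed functions in which both tasks share a first block and then diverge.

```python
from typing import Callable, Dict, List

Vector = List[float]  # stand-in for torch.Tensor in this sketch


def toy_forward(inputs: Vector,
                task_paths: Dict[str, Callable[[Vector], Vector]],
                decoders: Dict[str, Callable[[Vector], Vector]]) -> Dict[str, Vector]:
    """Return a {task_name: prediction} dict, like the documented forward output.

    task_paths[t] maps the input to task t's representation, i.e. the route
    through the (already branched) encoder that task t follows.
    """
    return {task: decode(task_paths[task](inputs)) for task, decode in decoders.items()}


# Usage: two hypothetical tasks share a first block, then branch.
def shared(x: Vector) -> Vector:
    return [2.0 * v for v in x]


paths = {
    "seg": shared,
    "depth": lambda x: [v + 1.0 for v in shared(x)],
}
heads = {
    "seg": lambda r: [sum(r)],
    "depth": lambda r: [max(r)],
}
preds = toy_forward([1.0, 2.0], paths, heads)
print(preds)  # {'seg': [6.0], 'depth': [5.0]}
```

Keying the output by task name lets a training loop pair each prediction with the matching label and loss without relying on positional order.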