LibMTL.architecture.PLE

class PLE(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)[source]

Bases: LibMTL.architecture.abstract_arch.AbsArchitecture

Progressive Layered Extraction (PLE).

This method is proposed in Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations (ACM RecSys 2020 Best Paper) and implemented by us.

Parameters
  • img_size (list) – The size of the input data. For example, [3, 224, 224] denotes input images of size 3x224x224.

  • num_experts (list) – The number of experts shared by all tasks, followed by the number of experts specific to each task. Each expert is an encoder network (see the configuration example after the warning below).

Warning

  • PLE does not work with multi-input problems, i.e., multi_input must be False.

  • PLE is only supported with ResNet-based encoders.
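
Example: the img_size and num_experts arguments are normally supplied through the arch_args dictionary when building LibMTL.Trainer. The sketch below is illustrative only: the two task names, the decoder heads, the loss/metric choices, and the optimizer settings are placeholder assumptions, and the Trainer keyword layout follows the library's example scripts and may vary across LibMTL versions.

    import torch.nn as nn
    from LibMTL import Trainer
    from LibMTL.architecture import PLE
    from LibMTL.weighting import EW
    from LibMTL.model import resnet18
    from LibMTL.loss import CELoss
    from LibMTL.metrics import AccMetric

    # Two hypothetical classification tasks that share one input image.
    task_dict = {'task_a': {'metrics': ['Acc'], 'metrics_fn': AccMetric(),
                            'loss_fn': CELoss(), 'weight': [1]},
                 'task_b': {'metrics': ['Acc'], 'metrics_fn': AccMetric(),
                            'loss_fn': CELoss(), 'weight': [1]}}

    # One linear head per task; 512 assumes a ResNet-18-style flattened feature.
    decoders = nn.ModuleDict({'task_a': nn.Linear(512, 10),
                              'task_b': nn.Linear(512, 5)})

    trainer = Trainer(task_dict=task_dict,
                      weighting=EW,
                      architecture=PLE,
                      encoder_class=resnet18,   # PLE needs a ResNet-based encoder
                      decoders=decoders,
                      rep_grad=False,
                      multi_input=False,        # PLE requires multi_input=False
                      optim_param={'optim': 'adam', 'lr': 1e-4, 'weight_decay': 1e-5},
                      scheduler_param={'scheduler': 'step', 'step_size': 100, 'gamma': 0.5},
                      weight_args={},
                      arch_args={'img_size': [3, 224, 224],  # channels, height, width
                                 'num_experts': [1, 1, 1]})  # [shared, task_a, task_b]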

forward(self, inputs, task_name=None)[source]
Parameters
  • inputs (torch.Tensor) – The input data.

  • task_name (str, default=None) – The task name corresponding to inputs if multi_input is True. For PLE, leave it as None because multi_input must be False.

Returns

A dictionary mapping each task name (str) to its prediction (torch.Tensor).

Return type

dict
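
Because multi_input must be False for PLE, forward is called with a single batch and task_name left as None, and it returns one prediction per task. A minimal sketch, assuming arch is a constructed PLE instance (the Trainer in the sketch above builds one internally) and that the batch shape matches img_size:

    import torch

    x = torch.randn(8, 3, 224, 224)    # a batch of 8 images matching img_size
    out = arch(x)                      # same as arch.forward(x, task_name=None)
    for task, pred in out.items():
        print(task, pred.shape)        # one prediction tensor per task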

get_share_params(self)[source]

Return the shared parameters of the model.
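
For example, the shared parameters can be inspected or optimized separately from the task-specific decoders. A short sketch, again assuming arch is a constructed PLE instance:

    import torch

    # Compare the number of shared parameters with the total parameter count.
    n_shared = sum(p.numel() for p in arch.get_share_params())
    n_total = sum(p.numel() for p in arch.parameters())
    print(f'{n_shared} of {n_total} parameters are shared across tasks')

    # Optimize only the shared parameters.
    shared_opt = torch.optim.SGD(arch.get_share_params(), lr=1e-3)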

zero_grad_share_params(self)[source]

Set gradients of the shared parameters to zero.
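
This is mainly useful when computing per-task gradients of the shared parameters, as gradient-based weighting strategies do. A hedged sketch of that pattern, where arch is an assumed PLE instance and losses is an assumed dict mapping task names to scalar losses:

    # Collect the gradient of each task's loss w.r.t. the shared parameters.
    per_task_grads = {}
    for task, loss in losses.items():
        arch.zero_grad_share_params()     # clear shared gradients between tasks
        loss.backward(retain_graph=True)  # populate .grad on shared parameters
        per_task_grads[task] = [None if p.grad is None else p.grad.clone()
                                for p in arch.get_share_params()]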