Customize an Architecture

This section explains how to implement a custom architecture on top of LibMTL.

Create a New Architecture Class

First, create a new architecture class that inherits from LibMTL.architecture.AbsArchitecture.

from LibMTL.architecture import AbsArchitecture

class NewArchitecture(AbsArchitecture):
    def __init__(self, task_name, encoder_class, decoders, rep_grad,
                 multi_input, device, **kwargs):
        # Forward every argument unchanged to the base class.
        super().__init__(task_name, encoder_class, decoders, rep_grad,
                         multi_input, device, **kwargs)
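
The constructor pattern above can be tried without installing LibMTL by using a stand-in base class. The sketch below is only an illustration: StandInBase is hypothetical and merely mimics the constructor signature shown above, the assumption that extra keyword arguments are kept in self.kwargs is not confirmed by this page, and num_experts is a made-up option.

```python
class StandInBase:
    """Hypothetical stand-in for AbsArchitecture so the sketch runs without LibMTL."""

    def __init__(self, task_name, encoder_class, decoders, rep_grad,
                 multi_input, device, **kwargs):
        self.task_name = task_name
        self.encoder_class = encoder_class
        self.decoders = decoders
        self.rep_grad = rep_grad
        self.multi_input = multi_input
        self.device = device
        self.kwargs = kwargs  # assumption: extra options are kept here


class NewArchitecture(StandInBase):
    def __init__(self, task_name, encoder_class, decoders, rep_grad,
                 multi_input, device, **kwargs):
        super().__init__(task_name, encoder_class, decoders, rep_grad,
                         multi_input, device, **kwargs)
        # 'num_experts' is a made-up option used only for illustration.
        self.num_experts = self.kwargs.get("num_experts", 1)


arch = NewArchitecture(
    task_name=["task_a", "task_b"],
    encoder_class=object,                       # placeholder encoder class
    decoders={"task_a": None, "task_b": None},  # placeholder decoders
    rep_grad=False,
    multi_input=False,
    device="cpu",
    num_experts=4,                              # travels through **kwargs
)
```

Passing architecture-specific options through **kwargs keeps the positional signature identical to the base class, so code that constructs architectures generically does not need to know about them.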

Override the Relevant Methods

LibMTL.architecture.AbsArchitecture has four important methods that you may need to override.

  • forward(): The forward pass. Its input and output formats are documented in LibMTL.architecture.AbsArchitecture.forward(). If you want to combine your architecture with more weighting strategies or apply it to more datasets, your override must handle both the single-input and multi-input cases, as well as both the rep-grad and param-grad cases (both pairs are described elsewhere in the LibMTL documentation).

  • get_share_params(): Returns the shared parameters of the model. By default, it returns all the parameters of the encoder. Override it if necessary.

  • zero_grad_share_params(): Sets the gradients of the shared parameters to zero. By default, it zeroes the gradients of all the encoder parameters. Override it if necessary.

  • _prepare_rep(): Used to compute the gradients with respect to the representations. More details can be found in the LibMTL documentation.
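
How the four methods above might fit together can be sketched with plain PyTorch. This is a minimal illustration under stated assumptions, not LibMTL's implementation: SketchBase is a hypothetical stand-in for AbsArchitecture whose behaviour is inferred from the descriptions above, the _prepare_rep(rep, task) signature is simplified for the sketch, and the task names and layer sizes are arbitrary.

```python
import torch
import torch.nn as nn


class SketchBase(nn.Module):
    """Hypothetical stand-in for AbsArchitecture (assumed behaviour only)."""

    def __init__(self, task_name, encoder_class, decoders, rep_grad,
                 multi_input, device, **kwargs):
        super().__init__()
        self.task_name = task_name
        self.encoder_class = encoder_class
        self.decoders = decoders
        self.rep_grad = rep_grad
        self.multi_input = multi_input
        self.device = device
        self.kwargs = kwargs
        self.rep_tasks = {}  # detached per-task copies of the representation

    def get_share_params(self):
        # Default described above: every encoder parameter is shared.
        return self.encoder.parameters()

    def zero_grad_share_params(self):
        # Default described above: zero the gradients of the encoder.
        for p in self.encoder.parameters():
            if p.grad is not None:
                p.grad.zero_()

    def _prepare_rep(self, rep, task):
        # param-grad: the decoder consumes the representation directly.
        if not self.rep_grad:
            return rep
        # rep-grad: give the decoder a detached leaf copy so the gradient
        # with respect to the representation can be read from it later.
        self.rep = rep
        self.rep_tasks[task] = rep.detach().clone().requires_grad_(True)
        return self.rep_tasks[task]


class NewArchitecture(SketchBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.encoder = self.encoder_class().to(self.device)

    def forward(self, inputs, task_name=None):
        rep = self.encoder(inputs)
        out = {}
        for task in self.task_name:
            # multi-input: decode only the task that owns this batch.
            if task_name is not None and task != task_name:
                continue
            out[task] = self.decoders[task](self._prepare_rep(rep, task))
        return out


tasks = ["task_a", "task_b"]
decoders = nn.ModuleDict({t: nn.Linear(8, 2) for t in tasks})
model = NewArchitecture(tasks, lambda: nn.Linear(4, 8), decoders,
                        rep_grad=True, multi_input=False,
                        device=torch.device("cpu"))

preds = model(torch.randn(3, 4))   # single-input: one batch, all tasks
preds["task_a"].sum().backward()   # gradient stops at the detached copy
```

In this sketch, with rep_grad=True the backward pass fills model.rep_tasks["task_a"].grad and leaves the encoder untouched, which is what lets a weighting strategy work on representation gradients before pushing them into the shared parameters. Calling model(x, task_name="task_b") returns a dictionary with only that task, which is how the sketch models the multi-input case.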