Customize a Weighting Strategy

This section describes how to implement a custom weighting strategy in LibMTL.

Create a New Weighting Class

First, create a new weighting class that inherits from LibMTL.weighting.AbsWeighting.

from LibMTL.weighting import AbsWeighting

class NewWeighting(AbsWeighting):
    def __init__(self):
        # Run the base-class initialization before adding any custom state.
        super().__init__()

Rewrite Relevant Methods

LibMTL.weighting.AbsWeighting has four important methods.

  • backward(): The main method of a weighting strategy. Its input and output formats are documented in LibMTL.weighting.AbsWeighting.backward(). If you want to combine your weighting method with more architectures or apply it to more datasets, your implementation needs to handle both the single-input and multi-input cases (refer to here) and both the rep-grad and param-grad cases (refer to here).

  • init_param(): Defines and initializes any trainable parameters. It does nothing by default and can be overridden if necessary.

  • _get_grads(): Returns the gradients of the representations (the rep-grad case) or of the shared parameters (the param-grad case).

  • _backward_new_grads(): Resets the gradients and makes a backward pass. It covers both the rep-grad and param-grad cases.

The _get_grads() and _backward_new_grads() methods are useful building blocks when you override backward(); you can find more details here.
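
The way these four methods fit together can be sketched without installing LibMTL or PyTorch. The following is only a toy illustration under stated assumptions: ToyAbsWeighting is a hypothetical stand-in and not the real AbsWeighting, its helper signatures and the list-based gradients are invented for the sketch, and InverseLossWeighting is a made-up strategy. Only the overall flow (compute per-task gradients, derive task weights, write back the combined gradient, return the weights) mirrors the description above.

```python
# Toy stand-in for the base class. This is NOT LibMTL's real AbsWeighting:
# gradients are plain lists of floats so the sketch runs without PyTorch.
class ToyAbsWeighting:
    def __init__(self):
        self.task_num = 0
        self.shared_grad = None  # combined gradient set by _backward_new_grads

    def init_param(self):
        # Hook for defining trainable parameters; does nothing by default.
        pass

    def _get_grads(self, losses):
        # Hypothetical gradients: pretend the gradient of task t with respect
        # to two shared parameters is [loss_t, 2 * loss_t].
        return [[loss, 2.0 * loss] for loss in losses]

    def _backward_new_grads(self, batch_weight, grads):
        # Overwrite the shared gradient with the weighted sum of task gradients.
        dim = len(grads[0])
        self.shared_grad = [
            sum(w * g[i] for w, g in zip(batch_weight, grads))
            for i in range(dim)
        ]


class InverseLossWeighting(ToyAbsWeighting):
    """Hypothetical strategy: weight each task by its normalized inverse loss."""

    def backward(self, losses, **kwargs):
        # 1. Collect the per-task gradients.
        grads = self._get_grads(losses)
        # 2. Derive one weight per task (here: inverse loss, normalized to sum to 1).
        inverse = [1.0 / loss for loss in losses]
        total = sum(inverse)
        batch_weight = [value / total for value in inverse]
        # 3. Write the combined gradient back.
        self._backward_new_grads(batch_weight, grads)
        # 4. Return the task weights that were used.
        return batch_weight


weighting = InverseLossWeighting()
weighting.task_num = 2
weighting.init_param()
weights = weighting.backward([0.5, 2.0])
```

In this toy run the task with the smaller loss receives the larger weight, and weighting.shared_grad holds the weighted combination of the two per-task gradients. A real strategy would subclass LibMTL.weighting.AbsWeighting and follow the input and output formats documented for its backward() method.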