Customize a Weighting Strategy
This section describes how to implement a new weighting strategy on top of LibMTL.
Create a New Weighting Class
First, create a new weighting class that inherits from LibMTL.weighting.AbsWeighting.
from LibMTL.weighting import AbsWeighting

class NewWeighting(AbsWeighting):
    def __init__(self):
        super(NewWeighting, self).__init__()
Rewrite Relevant Methods
LibMTL.weighting.AbsWeighting has four important methods:

- backward(): The main method of a weighting strategy. Its input and output formats are documented in LibMTL.weighting.AbsWeighting.backward(). If you want your weighting method to work with more architectures or on more datasets, your rewritten backward() should handle both the single-input and multi-input cases (see here) and both the rep-grad and param-grad cases (see here).
- init_param(): Defines and initializes any trainable parameters the strategy needs. It does nothing by default and can be overridden if necessary.
- _get_grads(): Returns the gradients of the representations (rep-grad case) or of the shared parameters (param-grad case).
- _backward_new_grads(): Resets the gradients and performs a backward pass, again covering the rep-grad and param-grad cases.
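To make the roles of init_param() and backward() concrete without depending on LibMTL internals, here is a minimal, dependency-free sketch. It assumes a hypothetical stand-in base class, ToyAbsWeighting, which is not LibMTL's AbsWeighting, and a hypothetical strategy, FixedRatioWeighting. Task losses are plain floats rather than tensors, so no autograd happens. The task_num constructor argument, the "ratios" keyword argument, and returning one weight per task are assumptions of this sketch only.

```python
class ToyAbsWeighting:
    """Hypothetical stand-in for LibMTL.weighting.AbsWeighting (not the real class)."""

    def __init__(self, task_num):
        # Passing task_num to the constructor is an assumption of this sketch.
        self.task_num = task_num

    def init_param(self):
        # No parameters by default; subclasses override this when needed.
        pass

    def backward(self, losses, **kwargs):
        # Every concrete strategy must provide its own backward().
        raise NotImplementedError


class FixedRatioWeighting(ToyAbsWeighting):
    """Hypothetical strategy: weight task t by ratios[t] / sum(ratios)."""

    def init_param(self):
        # Assumed to be called once after construction.
        # A real strategy would create trainable tensors here; plain floats
        # keep this sketch dependency-free.
        self.ratios = [1.0] * self.task_num

    def backward(self, losses, **kwargs):
        # "ratios" is a hypothetical hyperparameter that may arrive via kwargs.
        ratios = kwargs.get("ratios", self.ratios)
        total = sum(ratios)
        weights = [r / total for r in ratios]
        # Weighted sum of the task losses. With tensors, this is where
        # loss.backward() would be called; here the value is only stored.
        self.last_loss = sum(w * loss for w, loss in zip(weights, losses))
        # Return one weight per task (an assumption of this sketch).
        return weights
```

For example, after FixedRatioWeighting(task_num=2) and init_param(), calling backward([0.5, 1.5], ratios=[1.0, 3.0]) returns [0.25, 0.75] and stores a combined loss of 1.25.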
The _get_grads() and _backward_new_grads() methods are particularly useful when rewriting backward(); more details can be found here.
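A gradient-based strategy usually follows a three-step pattern inside backward(): obtain per-task gradients (the job of _get_grads()), compute one weight per task from them, and apply the weighted gradients (the job of _backward_new_grads()). The pure-Python sketch below imitates that flow loosely for the param-grad case, with gradients stored as lists of floats shaped [task_num, grad_dim]. The helpers toy_get_grads and toy_backward_new_grads are hypothetical stand-ins, not LibMTL's methods, and the inverse-norm weighting rule is chosen only for illustration; it is not a method shipped with LibMTL.

```python
import math


def toy_get_grads(per_task_grads):
    """Hypothetical stand-in for _get_grads(): returns a [task_num, grad_dim] matrix."""
    return [list(g) for g in per_task_grads]


def toy_backward_new_grads(batch_weight, grads):
    """Hypothetical stand-in for _backward_new_grads(): combines the task
    gradients with the given weights into one shared gradient (no autograd)."""
    grad_dim = len(grads[0])
    return [sum(w * g[i] for w, g in zip(batch_weight, grads)) for i in range(grad_dim)]


def inverse_norm_backward(per_task_grads, eps=1e-12):
    """Illustrative rule: rescale each task's gradient to (roughly) unit L2 norm."""
    grads = toy_get_grads(per_task_grads)
    # L2 norm of each task's flattened gradient.
    norms = [math.sqrt(sum(x * x for x in g)) for g in grads]
    # eps guards against division by zero for a task with a zero gradient.
    batch_weight = [1.0 / (n + eps) for n in norms]
    new_grad = toy_backward_new_grads(batch_weight, grads)
    return batch_weight, new_grad
```

With per-task gradients [[3.0, 4.0], [0.0, 2.0]], the norms are 5 and 2, so the weights are approximately [0.2, 0.5] and the combined gradient is approximately [0.6, 1.8]. Keeping gradient extraction, weight computation, and gradient application as separate steps means only the middle step changes from one gradient-based strategy to another.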