Customize a Weighting Strategy

Here we introduce how to customize a new weighting strategy with the support of LibMTL.
Create a New Weighting Class

First, you need to create a new weighting class by inheriting from LibMTL.weighting.AbsWeighting:

```python
from LibMTL.weighting import AbsWeighting

class NewWeighting(AbsWeighting):
    def __init__(self):
        super(NewWeighting, self).__init__()
```
Rewrite Relevant Methods

There are four important methods in LibMTL.weighting.AbsWeighting:

- backward(): the main method of a weighting strategy; its input and output formats can be found in LibMTL.weighting.AbsWeighting.backward(). To rewrite this method, you need to consider the case of single-input versus multi-input (refer to here) and the case of rep-grad versus param-grad (refer to here) if you want to combine your weighting method with more architectures or apply your method to more datasets.
- init_param(): defines and initializes trainable parameters. It does nothing by default and can be rewritten if necessary.
- _get_grads(): returns the gradients of the representations or the shared parameters (corresponding to the rep-grad and param-grad cases, respectively).
- _backward_new_grads(): resets the gradients and performs a backward pass (corresponding to the rep-grad and param-grad cases, respectively).
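As a concrete illustration, here is a minimal sketch of a weighting class that applies equal weights to all task losses. The AbsWeighting class below is a hypothetical stand-in so the snippet runs on its own; in practice you would inherit from the real LibMTL.weighting.AbsWeighting, and losses would be a tensor of task losses on which you call .backward().

```python
# Hypothetical stand-in for LibMTL.weighting.AbsWeighting so this
# sketch is self-contained; inherit from the real class in practice.
class AbsWeighting:
    def __init__(self):
        self.task_num = 0  # set by the framework in the real class

class EqualWeighting(AbsWeighting):
    """Weight every task loss equally (a plain average)."""

    def backward(self, losses, **kwargs):
        # One weight per task; here losses is a plain list of floats
        # for illustration, not a tensor of autograd losses.
        batch_weight = [1.0 / len(losses)] * len(losses)
        total_loss = sum(w * l for w, l in zip(batch_weight, losses))
        # With real loss tensors you would call total_loss.backward() here.
        return batch_weight
```

For example, EqualWeighting().backward([2.0, 4.0]) returns [0.5, 0.5], i.e. each of the two tasks contributes half of the total loss.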
The _get_grads() and _backward_new_grads() methods are particularly useful when rewriting the backward() method, and you can find more details here.
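To make that division of labor concrete, the sketch below imitates how backward() can be built on top of _get_grads() and _backward_new_grads(). Both helpers are stubbed with plain Python lists so the example runs standalone; their real signatures and tensor types are defined by LibMTL, and the inverse-gradient-norm weights used here are purely illustrative, not a strategy from the library.

```python
class NormWeighting:
    """Illustrative only: weight tasks inversely to their gradient norm."""

    def _get_grads(self, losses):
        # Stub: pretend each task produced a flat gradient vector over
        # the shared parameters (the param-grad case). The real method
        # computes these per-task gradients via autograd.
        return [[l, l] for l in losses]

    def _backward_new_grads(self, batch_weight, grads):
        # Stub: combine the per-task gradients into one new shared
        # gradient. The real method writes this back to the parameters
        # (or representations) and triggers the backward pass.
        dim = len(grads[0])
        return [sum(w * g[i] for w, g in zip(batch_weight, grads))
                for i in range(dim)]

    def backward(self, losses, **kwargs):
        grads = self._get_grads(losses)
        # Give each task a weight inversely proportional to its
        # gradient norm, then normalize the weights to sum to one.
        norms = [sum(g * g for g in grad) ** 0.5 for grad in grads]
        inv = [1.0 / n for n in norms]
        batch_weight = [v / sum(inv) for v in inv]
        self._backward_new_grads(batch_weight, grads)
        return batch_weight
```

The point of the pattern is that backward() only decides the task weights; fetching the raw per-task gradients and applying the reweighted combination are delegated to the two helper methods, which already handle the rep-grad and param-grad cases for you.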