import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from LibMTL.architecture.abstract_arch import AbsArchitecture
class _transform_resnet_MTAN(nn.Module):
    r"""Turn a ResNet backbone into an MTAN encoder with per-task attention.

    The stem (``conv1``, ``bn1``, ``relu``, ``maxpool``) and the four residual
    stages (``layer1`` .. ``layer4``) of ``resnet_network`` are reused as the
    shared trunk. Each stage is split into "every block but the last" and
    "the last block". For every task an attention mask is computed from the
    output of the first part and multiplied with the output of the second part.

    Args:
        resnet_network: Backbone exposing ``conv1``, ``bn1``, ``relu``,
            ``maxpool``, ``feature_dim`` and sliceable ``layer1`` .. ``layer4``.
        task_name (list): Names of the tasks; one attention branch is built
            per entry.
        device: Device the newly created attention modules are moved to.
    """

    def __init__(self, resnet_network, task_name, device):
        super(_transform_resnet_MTAN, self).__init__()

        self.task_name = task_name
        self.task_num = len(task_name)
        self.device = device
        # When set to a task name, ``forward`` only evaluates that task's branch.
        self.forward_task = None

        # A 2048-dim feature is treated as a Bottleneck-style backbone
        # (channel expansion 4); anything else as expansion 1.
        self.expansion = 4 if resnet_network.feature_dim == 2048 else 1
        # Plain Python ints so that no NumPy scalar ends up inside the
        # Conv/BatchNorm constructor arguments.
        ch = [width * self.expansion for width in (64, 128, 256, 512)]

        self.shared_conv = nn.Sequential(resnet_network.conv1, resnet_network.bn1,
                                         resnet_network.relu, resnet_network.maxpool)

        self.shared_layer = nn.ModuleDict({})
        self.encoder_att = nn.ModuleDict({})
        self.encoder_block_att = nn.ModuleList([])
        for i in range(4):
            # getattr instead of eval(): same lookup, no string execution.
            stage = getattr(resnet_network, 'layer' + str(i + 1))
            self.shared_layer[str(i)] = nn.ModuleList([stage[:-1], stage[-1]])

            # From the second stage on, the attention input is the shared
            # feature concatenated with the previous attended feature, hence
            # twice the channel count.
            att_in = ch[0] if i == 0 else 2 * ch[i]
            self.encoder_att[str(i)] = nn.ModuleList([
                self._att_layer(att_in, ch[i] // self.expansion, ch[i]).to(self.device)
                for _ in range(self.task_num)
            ])

            # One block per stage transition; it is indexed by stage only, so
            # the same block is applied to every task's attended feature.
            if i < 3:
                self.encoder_block_att.append(
                    self._conv_layer(ch[i], ch[i + 1] // self.expansion).to(self.device))

        self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2)

    def _att_layer(self, in_channel, intermediate_channel, out_channel):
        """Build a 1x1-conv attention head ending in a sigmoid mask."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=intermediate_channel,
                      kernel_size=1, padding=0),
            nn.BatchNorm2d(intermediate_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=intermediate_channel, out_channels=out_channel,
                      kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channel),
            nn.Sigmoid())

    def _conv_layer(self, in_channel, out_channel):
        """Build the residual block applied to an attended feature.

        A ``Bottleneck`` is used when ``self.expansion == 4`` and a
        ``BasicBlock`` otherwise; both receive a 1x1-conv + BatchNorm
        ``downsample`` branch producing ``self.expansion * out_channel``
        channels.
        """
        # Imported lazily, exactly as in the original implementation.
        from LibMTL.model.resnet import conv1x1

        downsample = nn.Sequential(
            conv1x1(in_channel, self.expansion * out_channel, stride=1),
            nn.BatchNorm2d(self.expansion * out_channel))
        if self.expansion == 4:
            from LibMTL.model.resnet import Bottleneck
            return Bottleneck(in_channel, out_channel, downsample=downsample)

        from LibMTL.model.resnet import BasicBlock
        return BasicBlock(in_channel, out_channel, downsample=downsample)

    def forward(self, inputs):
        """Compute the attended representation(s).

        Returns:
            A list with one entry per task when ``self.forward_task`` is
            ``None``; otherwise only the tensor of ``self.forward_task``.
        """
        selected = self.forward_task
        # Entries of tasks that are skipped keep the placeholder ``0``.
        att_rep = [0] * self.task_num

        shared_in = self.shared_conv(inputs)
        for i in range(4):
            stage = self.shared_layer[str(i)]
            head_rep = stage[0](shared_in)   # all blocks but the last
            tail_rep = stage[1](head_rep)    # last block of the stage

            for tn, task in enumerate(self.task_name):
                if selected is not None and task != selected:
                    continue

                if i == 0:
                    att_mask = self.encoder_att[str(i)][tn](head_rep)
                else:
                    # Align spatial size with the shared feature before
                    # concatenating along the channel dimension.
                    if head_rep.size()[-2:] != att_rep[tn].size()[-2:]:
                        att_rep[tn] = self.down_sampling(att_rep[tn])
                    att_mask = self.encoder_att[str(i)][tn](
                        torch.cat([head_rep, att_rep[tn]], dim=1))

                att_rep[tn] = att_mask * tail_rep
                if i < 3:
                    att_rep[tn] = self.encoder_block_att[i](att_rep[tn])
                if i == 0:
                    att_rep[tn] = self.down_sampling(att_rep[tn])

            shared_in = tail_rep

        if selected is None:
            return att_rep
        return att_rep[self.task_name.index(selected)]
class MTAN(AbsArchitecture):
    r"""Multi-Task Attention Network (MTAN).

    This method is proposed in `End-To-End Multi-Task Learning With Attention (CVPR 2019) <https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_End-To-End_Multi-Task_Learning_With_Attention_CVPR_2019_paper.pdf>`_ \
    and implemented by modifying the `official PyTorch implementation <https://github.com/lorenmt/mtan>`_.

    .. warning::
            :class:`MTAN` is only supported by ResNet-based encoders.

    """
    def __init__(self, task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs):
        super(MTAN, self).__init__(task_name, encoder_class, decoders, rep_grad, multi_input, device, **kwargs)

        self.encoder = self.encoder_class()
        # Two encoder layouts are supported: a ResNet itself (it exposes
        # ``layer1``) or a wrapper that holds the ResNet in ``resnet_network``.
        # An explicit attribute test is used so that a failure while building
        # the attention encoder surfaces as-is instead of being swallowed.
        if hasattr(self.encoder, 'layer1'):
            self.encoder = _transform_resnet_MTAN(self.encoder.to(device), task_name, device)
        else:
            self.encoder.resnet_network = _transform_resnet_MTAN(
                self.encoder.resnet_network.to(device), task_name, device)

    def _mtan_encoder(self):
        """Return the attention encoder, unwrapping ``self.encoder.resnet_network`` if present."""
        return getattr(self.encoder, 'resnet_network', self.encoder)

    def forward(self, inputs, task_name=None):
        r"""Run the encoder and the decoder(s).

        Args:
            inputs: Input passed to the encoder.
            task_name (str, default=None): When given, only this task's
                decoder is evaluated. If ``self.multi_input`` is set, the
                attention encoder is also restricted to this task.

        Returns:
            dict: Maps each evaluated task name to its decoder output.
        """
        out = {}
        if self.multi_input:
            self._mtan_encoder().forward_task = task_name
        s_rep = self.encoder(inputs)
        for tn, task in enumerate(self.task_name):
            if task_name is not None and task != task_name:
                continue
            # The encoder returns a list (one entry per task) when all tasks
            # are evaluated, and a single tensor otherwise.
            ss_rep = s_rep[tn] if isinstance(s_rep, list) else s_rep
            # ``_prepare_rep`` is inherited; its behavior is not visible here.
            ss_rep = self._prepare_rep(ss_rep, task, same_rep=False)
            out[task] = self.decoders[task](ss_rep)
        return out

    def get_share_params(self):
        r"""Return the shared parameters as a list.

        These are the parameters of ``shared_conv`` and ``shared_layer`` of
        the attention encoder. When the encoder is a wrapper, every wrapper
        parameter whose name does not contain ``'resnet_network'`` is added.
        The attention modules are not included.
        """
        r = self._mtan_encoder()
        p = []
        p += r.shared_conv.parameters()
        p += r.shared_layer.parameters()
        if r is not self.encoder:
            p += [param for n, param in self.encoder.named_parameters()
                  if 'resnet_network' not in n]
        return p

    def zero_grad_share_params(self):
        r"""Set the gradients of the shared parameters to zero.

        Only the parameters returned by :meth:`get_share_params` are touched.
        Iterating ``self.encoder.named_modules()`` and calling ``zero_grad()``
        on each match would also hit the root module (its name is ``''``) and
        thereby clear the gradients of the task-specific attention modules.
        """
        for param in self.get_share_params():
            grad = param.grad
            if grad is None:
                continue
            # Same steps ``nn.Module.zero_grad`` performs for in-place zeroing.
            if grad.grad_fn is not None:
                grad.detach_()
            else:
                grad.requires_grad_(False)
            grad.zero_()