import torch.nn as nn
import LibMTL.model.resnet as resnet
[docs]class ResnetDilated(nn.Module):
def __init__(self, orig_resnet, dilate_scale=8):
super(ResnetDilated, self).__init__()
from functools import partial
if dilate_scale == 8:
orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
elif dilate_scale == 16:
orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
# take pre-defined ResNet, except AvgPool and FC
self.conv1 = orig_resnet.conv1
self.bn1 = orig_resnet.bn1
self.relu = orig_resnet.relu
self.maxpool = orig_resnet.maxpool
self.layer1 = orig_resnet.layer1
self.layer2 = orig_resnet.layer2
self.layer3 = orig_resnet.layer3
self.layer4 = orig_resnet.layer4
self.feature_dim = orig_resnet.feature_dim
def _nostride_dilate(self, m, dilate):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
# the convolution with stride
if m.stride == (2, 2):
m.stride = (1, 1)
if m.kernel_size == (3, 3):
m.dilation = (dilate//2, dilate//2)
m.padding = (dilate//2, dilate//2)
# other convoluions
else:
if m.kernel_size == (3, 3):
m.dilation = (dilate, dilate)
m.padding = (dilate, dilate)
[docs] def forward(self, x):
x = self.relu(self.bn1(self.conv1(x)))
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
[docs]def resnet_dilated(basenet, pretrained=True, dilate_scale=8):
r"""Dilated Residual Network models from `"Dilated Residual Networks" <https://openaccess.thecvf.com/content_cvpr_2017/papers/Yu_Dilated_Residual_Networks_CVPR_2017_paper.pdf>`_
Args:
basenet (str): The type of ResNet.
pretrained (bool): If True, returns a model pre-trained on ImageNet.
dilate_scale ({8, 16}, default=8): The type of dilating process.
"""
return ResnetDilated(resnet.__dict__[basenet](pretrained=pretrained), dilate_scale=dilate_scale)