diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py
index dd9cea2c5c0..75291814f40 100644
--- a/torchvision/models/resnet.py
+++ b/torchvision/models/resnet.py
@@ -1,6 +1,7 @@
 import torch.nn as nn
 import math
 import torch.utils.model_zoo as model_zoo
+import numpy as np
 
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -16,21 +17,35 @@
 }
 
 
-def conv3x3(in_planes, out_planes, stride=1):
+def conv3x3(in_planes, out_planes, stride=1, dilation=1):
     "3x3 convolution with padding"
-    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
-                     padding=1, bias=False)
+
+    kernel_size = np.asarray((3, 3))
+
+    # Compute the size of the upsampled filter with
+    # a specified dilation rate.
+    upsampled_kernel_size = (kernel_size - 1) * (dilation - 1) + kernel_size
+
+    # Determine the padding that is necessary for full padding,
+    # meaning the output spatial size is equal to input spatial size
+    full_padding = (upsampled_kernel_size - 1) // 2
+
+    # Conv2d doesn't accept numpy arrays as arguments
+    full_padding, kernel_size = tuple(full_padding), tuple(kernel_size)
+
+    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                     padding=full_padding, dilation=dilation, bias=False)
 
 
 class BasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
+    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
         super(BasicBlock, self).__init__()
-        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
         self.bn1 = nn.BatchNorm2d(planes)
         self.relu = nn.ReLU(inplace=True)
-        self.conv2 = conv3x3(planes, planes)
+        self.conv2 = conv3x3(planes, planes, dilation=dilation)
         self.bn2 = nn.BatchNorm2d(planes)
         self.downsample = downsample
         self.stride = stride
@@ -57,12 +72,16 @@ def forward(self, x):
 class Bottleneck(nn.Module):
     expansion = 4
 
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
+    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
         super(Bottleneck, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
         self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
-                               padding=1, bias=False)
+
+        self.conv2 = conv3x3(planes, planes, stride=stride, dilation=dilation)
+
+        #self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+        #                       padding=1, bias=False)
+
         self.bn2 = nn.BatchNorm2d(planes)
         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
         self.bn3 = nn.BatchNorm2d(planes * 4)
@@ -95,20 +114,43 @@ def forward(self, x):
 
 class ResNet(nn.Module):
 
-    def __init__(self, block, layers, num_classes=1000):
+    def __init__(self,
+                 block,
+                 layers,
+                 num_classes=1000,
+                 fully_conv=False,
+                 remove_avg_pool_layer=False,
+                 output_stride=32):
+
+        # Add additional variables to track
+        # output stride. Necessary to achieve
+        # specified output stride.
+        self.output_stride = output_stride
+        self.current_stride = 4
+        self.current_dilation = 1
+
+        self.remove_avg_pool_layer = remove_avg_pool_layer
+
         self.inplanes = 64
+        self.fully_conv = fully_conv
         super(ResNet, self).__init__()
         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                bias=False)
         self.bn1 = nn.BatchNorm2d(64)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
         self.layer1 = self._make_layer(block, 64, layers[0])
         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
         self.avgpool = nn.AvgPool2d(7)
         self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        if self.fully_conv:
+            self.avgpool = nn.AvgPool2d(7, padding=3, stride=1)
+            self.fc = nn.Conv2d(512 * block.expansion, num_classes, 1)
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -118,9 +160,27 @@ def __init__(self, block, layers, num_classes=1000):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
 
-    def _make_layer(self, block, planes, blocks, stride=1):
+    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
         downsample = None
+
         if stride != 1 or self.inplanes != planes * block.expansion:
+
+
+            # Check if we already achieved desired output stride.
+            if self.current_stride == self.output_stride:
+
+                # If so, replace subsampling with a dilation to preserve
+                # current spatial resolution.
+                self.current_dilation = self.current_dilation * stride
+                stride = 1
+            else:
+
+                # If not, perform subsampling and update current
+                # new output stride.
+                self.current_stride = self.current_stride * stride
+
+
+            # We don't dilate 1x1 convolution.
             downsample = nn.Sequential(
                 nn.Conv2d(self.inplanes, planes * block.expansion,
                           kernel_size=1, stride=stride, bias=False),
@@ -128,10 +188,10 @@ def _make_layer(self, block, planes, blocks, stride=1):
         )
 
         layers = []
-        layers.append(block(self.inplanes, planes, stride, downsample))
+        layers.append(block(self.inplanes, planes, stride, downsample, dilation=self.current_dilation))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes))
+            layers.append(block(self.inplanes, planes, dilation=self.current_dilation))
 
         return nn.Sequential(*layers)
 
@@ -145,9 +205,13 @@ def forward(self, x):
         x = self.layer2(x)
         x = self.layer3(x)
         x = self.layer4(x)
-
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
+
+        if not self.remove_avg_pool_layer:
+            x = self.avgpool(x)
+
+        if not self.fully_conv:
+            x = x.view(x.size(0), -1)
+
         x = self.fc(x)
 
         return x
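
A note on the `conv3x3` padding arithmetic above: for a fixed 3x3 kernel the "full padding" formula collapses to `padding == dilation`, since the dilated kernel spans `2*dilation + 1` input pixels and `(2*dilation + 1 - 1) // 2 == dilation`. A minimal standalone check (not part of the patch) that spatial size is preserved at stride 1:

```python
import torch
import torch.nn as nn

# For a 3x3 kernel, the full padding computed in conv3x3 reduces to
# padding == dilation, so the output spatial size equals the input
# spatial size whenever stride == 1, regardless of the dilation rate.
for d in (1, 2, 4):
    conv = nn.Conv2d(8, 8, kernel_size=3, padding=d, dilation=d, bias=False)
    out = conv(torch.randn(1, 8, 56, 56))
    assert out.shape[-2:] == (56, 56)
```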
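And a usage sketch for the new constructor arguments (assumes the patched `resnet.py` above is importable and a recent PyTorch; the 21-class head is only an illustrative choice, not part of the diff):

```python
import torch
from torchvision.models.resnet import ResNet, Bottleneck

# ResNet-50 body with output_stride=8: conv1 + maxpool contribute stride 4
# and layer2 doubles it to 8; _make_layer then replaces the stride-2
# subsampling of layer3 and layer4 with dilation rates 2 and 4.
model = ResNet(Bottleneck,
               [3, 4, 6, 3],            # ResNet-50 block counts
               num_classes=21,          # illustrative, e.g. PASCAL VOC
               fully_conv=True,
               remove_avg_pool_layer=True,
               output_stride=8)

out = model(torch.randn(1, 3, 224, 224))
print(out.size())  # torch.Size([1, 21, 28, 28]), i.e. 224 / 8 = 28
```

With `fully_conv=True` the classifier is a 1x1 convolution, so the network emits a dense score map that can be upsampled back to the input resolution for segmentation.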